autogluon.multimodal 1.2.1b20250303__tar.gz → 1.2.1b20250305__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/PKG-INFO +1 -1
  2. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/setup.py +3 -3
  3. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/__init__.py +10 -0
  4. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/constants.py +16 -5
  5. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/__init__.py +14 -2
  6. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/dataset.py +2 -2
  7. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/infer_types.py +16 -2
  8. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/label_encoder.py +3 -3
  9. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/data}/nlpaug.py +4 -4
  10. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  11. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_categorical.py +35 -6
  12. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_document.py +59 -33
  13. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/data/process_image.py +388 -0
  14. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_label.py +7 -3
  15. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  16. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  17. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  18. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/data/process_ner.py +359 -0
  19. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_numerical.py +32 -5
  20. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  21. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_text.py +95 -58
  22. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/template_engine.py +7 -9
  23. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/templates.py +0 -2
  24. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/trivial_augmenter.py +2 -2
  25. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils/data.py → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/data/utils.py +257 -117
  26. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/__init__.py +2 -1
  27. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/base.py +189 -189
  28. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/learners/ensemble.py +748 -0
  29. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/few_shot_svm.py +6 -15
  30. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/matching.py +59 -84
  31. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/ner.py +23 -22
  32. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/object_detection.py +26 -21
  33. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  34. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/__init__.py +12 -3
  35. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/models/augmenter.py +175 -0
  36. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/categorical_mlp.py +13 -8
  37. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/clip.py +92 -18
  38. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/custom_transformer.py +75 -75
  39. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/document_transformer.py +23 -9
  40. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/ft_transformer.py +40 -35
  41. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/fusion/base.py +2 -4
  42. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  43. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  44. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  45. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/models/huggingface_text.py → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/models/hf_text.py +21 -2
  46. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/models/meta_transformer.py +336 -0
  47. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/mlp.py +6 -6
  48. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  49. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  50. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/ner_text.py +1 -8
  51. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/numerical_mlp.py +14 -8
  52. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/sam.py +12 -2
  53. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/t_few.py +21 -5
  54. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/timm_image.py +74 -32
  55. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/models/utils.py +1766 -0
  56. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/__init__.py +17 -0
  57. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_distiller.py +2 -1
  58. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_matcher.py +4 -10
  59. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_mmdet.py +2 -10
  60. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_module.py +139 -14
  61. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_ner.py +3 -3
  62. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/lit_semantic_seg.py +1 -1
  63. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/__init__.py +14 -0
  64. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  65. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  66. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  67. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  68. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  69. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  70. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/losses/utils.py +313 -0
  71. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/lr/__init__.py +1 -0
  72. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/lr/utils.py +332 -0
  73. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics/__init__.py +4 -0
  74. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  75. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  76. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  77. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics/utils.py +359 -0
  78. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/utils.py +284 -0
  79. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/predictor.py +51 -12
  80. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/__init__.py +19 -45
  81. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/cache.py +23 -2
  82. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/checkpoint.py +58 -5
  83. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/config.py +127 -55
  84. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/device.py +120 -0
  85. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/distillation.py +8 -8
  86. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/download.py +1 -1
  87. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/env.py +22 -0
  88. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/export.py +3 -3
  89. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/hpo.py +5 -5
  90. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/inference.py +37 -4
  91. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/install.py +91 -0
  92. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/load.py +52 -47
  93. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/log.py +6 -41
  94. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/matcher.py +3 -2
  95. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/onnx.py +0 -4
  96. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/path.py +10 -0
  97. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/precision.py +130 -0
  98. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils}/presets.py +259 -66
  99. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils}/problem_types.py +30 -1
  100. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/save.py +47 -29
  101. autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils/strategy.py +24 -0
  102. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/version.py +1 -1
  103. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/PKG-INFO +1 -1
  104. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/SOURCES.txt +40 -22
  105. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/requires.txt +4 -4
  106. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/__init__.py +0 -8
  107. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/data/process_image.py +0 -353
  108. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/data/process_ner.py +0 -171
  109. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/data/utils.py +0 -615
  110. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/models/utils.py +0 -905
  111. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization/__init__.py +0 -16
  112. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization/losses.py +0 -394
  113. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization/utils.py +0 -1054
  114. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils/cloud_io.py +0 -80
  115. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils/environment.py +0 -395
  116. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils/metric.py +0 -500
  117. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/utils/model.py +0 -558
  118. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/setup.cfg +0 -0
  119. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/cli/__init__.py +0 -0
  120. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/cli/prepare_detection_dataset.py +0 -0
  121. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/cli/voc2coco.py +0 -0
  122. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/__init__.py +0 -0
  123. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/__init__.py +0 -0
  124. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/__init__.py +0 -0
  125. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/coco_detection.py +0 -0
  126. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/default_runtime.py +0 -0
  127. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/dino/dino-4scale_r50_8xb2-12e_coco.py +0 -0
  128. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/dino/dino-5scale_swin-l_8xb2-12e_coco.py +0 -0
  129. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/dino/dino-5scale_swin-l_8xb2-36e_coco.py +0 -0
  130. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/dino/dino_swinl_tta.py +0 -0
  131. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/dino/dino_tta.py +0 -0
  132. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/faster_rcnn/__init__.py +0 -0
  133. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/faster_rcnn/faster_rcnn_r50_fpn.py +0 -0
  134. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/schedule_1x.py +0 -0
  135. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/voc/__init__.py +0 -0
  136. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/voc/faster_rcnn_r50_fpn_1x_voc0712.py +0 -0
  137. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/voc/voc0712.py +0 -0
  138. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/__init__.py +0 -0
  139. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_l_8xb8-300e_coco.py +0 -0
  140. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_m_8xb8-300e_coco.py +0 -0
  141. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_nano_8xb8-300e_coco.py +0 -0
  142. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_s_8xb8-300e_coco.py +0 -0
  143. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_tiny_8xb8-300e_coco.py +0 -0
  144. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_tta.py +0 -0
  145. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/configs/pretrain/detection/yolox/yolox_x_8xb8-300e_coco.py +0 -0
  146. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/collator.py +0 -0
  147. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/datamodule.py +0 -0
  148. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/dataset_mmlab/__init__.py +0 -0
  149. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/dataset_mmlab/multi_image_mix_dataset.py +0 -0
  150. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/mixup.py +0 -0
  151. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/process_mmlab/__init__.py +0 -0
  152. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/data/randaug.py +0 -0
  153. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/adaptation_layers.py +0 -0
  154. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/custom_hf_models/modeling_sam_for_conv_lora.py +0 -0
  155. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/fusion/__init__.py +0 -0
  156. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/models/mmdet_image.py +0 -0
  157. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim}/deepspeed.py +0 -0
  158. autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization/lr_scheduler.py → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/lr/lr_schedulers.py +0 -0
  159. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal/optimization → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/optim/metrics}/semantic_seg_metrics.py +0 -0
  160. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/colormap.py +0 -0
  161. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/label_studio.py +0 -0
  162. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/misc.py +0 -0
  163. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/mmcv.py +0 -0
  164. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/object_detection.py +0 -0
  165. {autogluon.multimodal-1.2.1b20250303/src/autogluon/multimodal → autogluon.multimodal-1.2.1b20250305/src/autogluon/multimodal/utils}/registry.py +0 -0
  166. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon/multimodal/utils/visualizer.py +0 -0
  167. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/dependency_links.txt +0 -0
  168. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/namespace_packages.txt +0 -0
  169. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/top_level.txt +0 -0
  170. {autogluon.multimodal-1.2.1b20250303 → autogluon.multimodal-1.2.1b20250305}/src/autogluon.multimodal.egg-info/zip-safe +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.multimodal
3
- Version: 1.2.1b20250303
3
+ Version: 1.2.1b20250305
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -61,7 +61,7 @@ install_requires = ag.get_dependency_version_ranges(install_requires)
61
61
 
62
62
  tests_require = [
63
63
  "ruff",
64
- "datasets>=2.10.0,<2.15.0",
64
+ "datasets>=2.16.0,<2.20.0",
65
65
  "onnx>=1.13.0,<1.16.2;platform_system=='Windows'", # cap at 1.16.1 for issue https://github.com/onnx/onnx/issues/6267
66
66
  "onnx>=1.13.0,<1.18.0;platform_system!='Windows'",
67
67
  "onnxruntime>=1.17.0,<1.20.0", # install for gpu system due to https://github.com/autogluon/autogluon/issues/3804
@@ -78,8 +78,8 @@ if __name__ == "__main__":
78
78
  setup_args["package_data"]["autogluon.multimodal"] = [
79
79
  "configs/data/*.yaml",
80
80
  "configs/model/*.yaml",
81
- "configs/optimization/*.yaml",
82
- "configs/environment/*.yaml",
81
+ "configs/optim/*.yaml",
82
+ "configs/env/*.yaml",
83
83
  "configs/distiller/*.yaml",
84
84
  "configs/matcher/*.yaml",
85
85
  ]
@@ -0,0 +1,10 @@
1
+ from autogluon.common.utils.log_utils import _add_stream_handler
2
+
3
+ try:
4
+ from .version import __version__
5
+ except ImportError:
6
+ pass
7
+
8
+ from .predictor import MultiModalPredictor
9
+
10
+ _add_stream_handler()
@@ -59,12 +59,17 @@ SEMANTIC_SEGMENTATION_GT = "semantic_segmentation_gt"
59
59
 
60
60
  # Output keys
61
61
  LOGITS = "logits"
62
+ ORI_LOGITS = "ori_logits"
63
+ AUG_LOGITS = "aug_logits"
62
64
  TEMPLATE_LOGITS = "template_logits"
63
65
  LM_TARGET = "lm_target"
64
66
  LOSS = "loss"
65
67
  OUTPUT = "output"
66
68
  WEIGHT = "weight"
67
69
  FEATURES = "features"
70
+ MULTIMODAL_FEATURES = "multimodal_features" # used for the adapted multimodal features before the fusion module
71
+ MULTIMODAL_FEATURES_PRE_AUG = "multimodal_features_pre_aug"
72
+ MULTIMODAL_FEATURES_POST_AUG = "multimodal_features_post_aug"
68
73
  RAW_FEATURES = "raw_features"
69
74
  MASKS = "masks"
70
75
  PROBABILITY = "probability"
@@ -73,6 +78,8 @@ BBOX = "bbox"
73
78
  ROIS = "rois"
74
79
  SCORE = "score"
75
80
  LOGIT_SCALE = "logit_scale"
81
+ VAE_MEAN = "vae_mean"
82
+ VAE_VAR = "vae_var"
76
83
 
77
84
  # Loss
78
85
  MOE_LOSS = "moe_loss"
@@ -142,6 +149,7 @@ FM = "fm"
142
149
  MAE = "mae"
143
150
  BER = "ber"
144
151
  IOU = "iou"
152
+ COVERAGE = "coverage"
145
153
  RETRIEVAL_METRICS = [NDCG, PRECISION, RECALL, MRR]
146
154
  METRIC_MODE_MAP = {
147
155
  ACC: MAX,
@@ -168,6 +176,7 @@ METRIC_MODE_MAP = {
168
176
  SM: MAX,
169
177
  IOU: MAX,
170
178
  BER: MIN,
179
+ COVERAGE: MAX,
171
180
  }
172
181
 
173
182
  MATCHING_METRICS = {
@@ -179,7 +188,7 @@ MATCHING_METRICS_WITHOUT_PROBLEM_TYPE = [RECALL, NDCG]
179
188
 
180
189
  EVALUATION_METRICS = {
181
190
  # Use evaluation metrics from METRICS for these types
182
- BINARY: METRICS[BINARY].keys(),
191
+ BINARY: list(METRICS[BINARY].keys()) + [COVERAGE],
183
192
  MULTICLASS: METRICS[MULTICLASS].keys(),
184
193
  REGRESSION: METRICS[REGRESSION].keys(),
185
194
  OBJECT_DETECTION: DETECTION_METRICS,
@@ -197,6 +206,7 @@ VALIDATION_METRICS = {
197
206
  # Training status
198
207
  TRAIN = "train"
199
208
  VALIDATE = "validate"
209
+ VAL = "val"
200
210
  TEST = "test"
201
211
  PREDICT = "predict"
202
212
 
@@ -217,11 +227,11 @@ Y_TRUE = "y_true"
217
227
  # Configuration keys
218
228
  MODEL = "model"
219
229
  DATA = "data"
220
- OPTIMIZATION = "optimization"
221
- ENVIRONMENT = "environment"
230
+ OPTIM = "optim"
231
+ ENV = "env"
222
232
  DISTILLER = "distiller"
223
233
  MATCHER = "matcher"
224
- VALID_CONFIG_KEYS = [MODEL, DATA, OPTIMIZATION, ENVIRONMENT, DISTILLER, MATCHER]
234
+ VALID_CONFIG_KEYS = [MODEL, DATA, OPTIM, ENV, DISTILLER, MATCHER]
225
235
 
226
236
  # Image normalization mean and std. This is only to normalize images for the CLIP model.
227
237
  CLIP_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
@@ -275,7 +285,7 @@ PEFT_STRATEGIES = list(set(PEFT_ADDITIVE_STRATEGIES) | set(PEFT_NON_ADDITIVE_STR
275
285
  # DeepSpeed constants
276
286
  DEEPSPEED_OFFLOADING = "deepspeed_stage_3_offload"
277
287
  DEEPSPEED_STRATEGY = "deepspeed"
278
- DEEPSPEED_MODULE = "autogluon.multimodal.optimization.deepspeed"
288
+ DEEPSPEED_MODULE = "autogluon.multimodal.optim.deepspeed"
279
289
  DEEPSPEED_MIN_PL_VERSION = "1.7.1"
280
290
 
281
291
  # registered model keys. TODO: document how to add new models.
@@ -298,6 +308,7 @@ DOCUMENT_TRANSFORMER = "document_transformer"
298
308
  HF_MODELS = (HF_TEXT, T_FEW, CLIP, NER_TEXT, DOCUMENT_TRANSFORMER)
299
309
  MMLAB_MODELS = (MMDET_IMAGE, MMOCR_TEXT_DET, MMOCR_TEXT_RECOG)
300
310
  SAM = "sam"
311
+ META_TRANSFORMER = "meta_transformer"
301
312
 
302
313
  # matcher loss type
303
314
  CONTRASTIVE_LOSS = "contrastive_loss"
@@ -1,4 +1,3 @@
1
- from . import collator, infer_types, randaug, utils
2
1
  from .datamodule import BaseDataModule
3
2
  from .dataset import BaseDataset
4
3
  from .dataset_mmlab import MultiImageMixDataset
@@ -9,8 +8,9 @@ from .infer_types import (
9
8
  infer_rois_column_type,
10
9
  is_image_column,
11
10
  )
12
- from .label_encoder import CustomLabelEncoder, NerLabelEncoder
13
11
  from .mixup import MixupModule
12
+ from .infer_types import infer_column_types, infer_output_shape, infer_problem_type, is_image_column, infer_ner_column_type
13
+ from .label_encoder import CustomLabelEncoder, NerLabelEncoder
14
14
  from .preprocess_dataframe import MultiModalFeaturePreprocessor
15
15
  from .process_categorical import CategoricalProcessor
16
16
  from .process_document import DocumentProcessor
@@ -21,3 +21,15 @@ from .process_ner import NerProcessor
21
21
  from .process_numerical import NumericalProcessor
22
22
  from .process_semantic_seg_img import SemanticSegImageProcessor
23
23
  from .process_text import TextProcessor
24
+ from .utils import (
25
+ create_data_processor,
26
+ create_fusion_data_processors,
27
+ data_to_df,
28
+ get_detected_data_types,
29
+ get_mixup,
30
+ infer_dtypes_by_model_names,
31
+ infer_scarcity_mode_by_data_size,
32
+ init_df_preprocessor,
33
+ split_train_tuning_data,
34
+ turn_on_off_feature_column_info,
35
+ )
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
4
4
  import pandas as pd
5
5
  import torch
6
6
 
7
- from ..constants import AUTOMM, GET_ITEM_ERROR_RETRY
7
+ from ..constants import GET_ITEM_ERROR_RETRY
8
8
  from .preprocess_dataframe import MultiModalFeaturePreprocessor
9
9
  from .utils import apply_data_processor, apply_df_preprocessor, get_per_sample_features
10
10
 
@@ -100,7 +100,7 @@ class BaseDataset(torch.utils.data.Dataset):
100
100
  per_ret = apply_data_processor(
101
101
  per_sample_features=per_sample_features,
102
102
  data_processors=per_processors_group,
103
- feature_modalities=getattr(self, f"modality_types_{group_id}"),
103
+ data_types=getattr(self, f"modality_types_{group_id}"),
104
104
  is_training=self.is_training,
105
105
  )
106
106
  ret.update(per_ret)
@@ -19,7 +19,6 @@ from ..constants import (
19
19
  DOCUMENT_IMAGE,
20
20
  DOCUMENT_PDF,
21
21
  IDENTIFIER,
22
- IMAGE,
23
22
  IMAGE_BASE64_STR,
24
23
  IMAGE_BYTEARRAY,
25
24
  IMAGE_PATH,
@@ -37,7 +36,6 @@ from ..constants import (
37
36
  TEXT,
38
37
  TEXT_NER,
39
38
  )
40
- from .utils import is_rois_input
41
39
 
42
40
  logger = logging.getLogger(__name__)
43
41
 
@@ -114,6 +112,22 @@ def is_categorical_column(
114
112
  return False
115
113
 
116
114
 
115
+ def is_rois_input(sample):
116
+ """
117
+ check if a sample is rois for object detection
118
+
119
+ Parameters
120
+ ----------
121
+ sample
122
+ The sampled data.
123
+
124
+ Returns
125
+ -------
126
+ bool, whether a sample is rois for object detection
127
+ """
128
+ return isinstance(sample, list) and len(sample) and isinstance(sample[0], list) and len(sample[0]) == 5
129
+
130
+
117
131
  def is_rois_column(data: pd.Series) -> bool:
118
132
  """
119
133
  Identify if a column is one rois column.
@@ -9,7 +9,7 @@ import pandas as pd
9
9
  from omegaconf import DictConfig, OmegaConf
10
10
  from sklearn.preprocessing import LabelEncoder
11
11
 
12
- from ..constants import AUTOMM, END_OFFSET, ENTITY_GROUP, NER_ANNOTATION, PROBABILITY, START_OFFSET
12
+ from ..constants import END_OFFSET, ENTITY_GROUP, PROBABILITY, START_OFFSET
13
13
 
14
14
  logger = logging.getLogger(__name__)
15
15
 
@@ -137,12 +137,12 @@ class NerLabelEncoder:
137
137
  transformed_y
138
138
  A list of word level annotations.
139
139
  """
140
- from .utils import process_ner_annotations
140
+ from .process_ner import NerProcessor
141
141
 
142
142
  all_annotations, _ = self.extract_ner_annotations(y)
143
143
  transformed_y = []
144
144
  for annotation, text_snippet in zip(all_annotations, x.items()):
145
- word_label, _, _, _ = process_ner_annotations(
145
+ word_label, _, _, _ = NerProcessor.process_ner_annotations(
146
146
  annotation, text_snippet[-1], self.entity_map, tokenizer, is_eval=True
147
147
  )
148
148
  word_label_invers = []
@@ -78,14 +78,14 @@ class InsertPunctuation(Augmenter):
78
78
  new = " ".join(new)
79
79
  return new
80
80
 
81
- @classmethod
82
- def clean(cls, data):
81
+ @staticmethod
82
+ def clean(data):
83
83
  if isinstance(data, list):
84
84
  return [d.strip() if d else d for d in data]
85
85
  return data.strip()
86
86
 
87
- @classmethod
88
- def is_duplicate(cls, dataset, data):
87
+ @staticmethod
88
+ def is_duplicate(dataset, data):
89
89
  for d in dataset:
90
90
  if d == data:
91
91
  return True
@@ -14,17 +14,14 @@ from sklearn.preprocessing import MinMaxScaler, StandardScaler
14
14
  from autogluon.features import CategoryFeatureGenerator
15
15
 
16
16
  from ..constants import (
17
- AUTOMM,
18
17
  CATEGORICAL,
19
18
  DOCUMENT,
20
- DOCUMENT_IMAGE,
21
19
  IDENTIFIER,
22
20
  IMAGE,
23
21
  IMAGE_BASE64_STR,
24
22
  IMAGE_BYTEARRAY,
25
23
  IMAGE_PATH,
26
24
  LABEL,
27
- NER,
28
25
  NER_ANNOTATION,
29
26
  NULL,
30
27
  NUMERICAL,
@@ -73,19 +70,17 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
73
70
 
74
71
  if label_column:
75
72
  if label_generator is None:
76
- self._label_generator = CustomLabelEncoder(
77
- positive_class=OmegaConf.select(config, "pos_label", default=None)
78
- )
73
+ self._label_generator = CustomLabelEncoder(positive_class=config.pos_label)
79
74
  else:
80
75
  self._label_generator = label_generator
81
76
 
82
77
  # Scaler used for numerical labels
83
- numerical_label_preprocessing = OmegaConf.select(config, "label.numerical_label_preprocessing")
78
+ numerical_label_preprocessing = config.label.numerical_preprocessing
84
79
  if numerical_label_preprocessing == "minmaxscaler":
85
80
  self._label_scaler = MinMaxScaler()
86
81
  elif numerical_label_preprocessing == "standardscaler":
87
82
  self._label_scaler = StandardScaler()
88
- elif numerical_label_preprocessing is None or numerical_label_preprocessing.lower() == "none":
83
+ elif numerical_label_preprocessing is None:
89
84
  self._label_scaler = StandardScaler(with_mean=False, with_std=False)
90
85
  else:
91
86
  raise ValueError(
@@ -135,8 +130,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
135
130
  # Some columns will be ignored
136
131
  self._ignore_columns_set = set()
137
132
  self._text_feature_names = []
138
- self._categorical_feature_names = []
139
- self._categorical_num_categories = []
133
+ self._categorical_num_categories = dict()
140
134
  self._numerical_feature_names = []
141
135
  self._image_feature_names = []
142
136
  self._rois_feature_names = []
@@ -154,10 +148,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
154
148
 
155
149
  @property
156
150
  def image_path_names(self):
157
- if hasattr(self, "_image_path_names"):
158
- return self._image_path_names
159
- else:
160
- return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
151
+ return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
161
152
 
162
153
  @property
163
154
  def rois_feature_names(self):
@@ -173,7 +164,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
173
164
 
174
165
  @property
175
166
  def image_feature_names(self):
176
- return self._image_path_names if hasattr(self, "_image_path_names") else self._image_feature_names
167
+ return self._image_feature_names
177
168
 
178
169
  @property
179
170
  def text_feature_names(self):
@@ -181,12 +172,21 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
181
172
 
182
173
  @property
183
174
  def categorical_feature_names(self):
184
- return self._categorical_feature_names
175
+ return list(self.categorical_num_categories.keys())
185
176
 
186
177
  @property
187
178
  def numerical_feature_names(self):
188
179
  return self._numerical_feature_names
189
180
 
181
@property
def numerical_fill_values(self):
    """
    Map every numerical feature column to the value its fitted generator
    produces for a missing (NaN) input.

    Returns
    -------
    A dict keyed by numerical column name. Each value is the single element
    obtained by transforming a 1x1 array holding NaN with that column's
    feature generator.
    """
    # A fresh 1x1 NaN probe is built per column so no generator ever sees an
    # array that another generator has already touched.
    return {
        name: self._feature_generators[name].transform(np.full([1, 1], np.nan))[:, 0][0]
        for name in self._numerical_feature_names
    }
189
+
190
190
  @property
191
191
  def document_feature_names(self):
192
192
  # Added for backward compatibility.
@@ -216,17 +216,12 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
216
216
 
217
217
  @property
218
218
  def required_feature_names(self):
219
- image_feature_names = (
220
- self._image_path_names if hasattr(self, "_image_path_names") else self._image_feature_names
221
- )
222
- rois_feature_names = self._rois_feature_names if hasattr(self, "_rois_feature_names") else []
223
-
224
219
  return (
225
- image_feature_names
220
+ self._image_feature_names
226
221
  + self._text_feature_names
227
222
  + self._numerical_feature_names
228
- + self._categorical_feature_names
229
- + rois_feature_names
223
+ + self.categorical_feature_names
224
+ + self._rois_feature_names
230
225
  )
231
226
 
232
227
  @property
@@ -268,16 +263,13 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
268
263
 
269
264
  def get_column_names(self, modality: str):
270
265
  if modality.startswith(IMAGE):
271
- if hasattr(self, "_image_path_names"):
272
- return self._image_path_names
273
- else:
274
- return self._image_feature_names
266
+ return self._image_feature_names
275
267
  elif modality == ROIS:
276
268
  return self._rois_feature_names
277
269
  elif modality == TEXT:
278
270
  return self._text_feature_names
279
271
  elif modality == CATEGORICAL:
280
- return self._categorical_feature_names
272
+ return self.categorical_feature_names
281
273
  elif modality == NUMERICAL:
282
274
  return self._numerical_feature_names
283
275
  elif modality.startswith(DOCUMENT):
@@ -344,8 +336,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
344
336
  continue
345
337
  num_categories = len(generator.category_map[col_name])
346
338
  # Add one unknown category
347
- self._categorical_num_categories.append(num_categories + 1)
348
- self._categorical_feature_names.append(col_name)
339
+ self._categorical_num_categories[col_name] = num_categories + 1
349
340
  elif col_type == NUMERICAL:
350
341
  processed_data = pd.to_numeric(col_value)
351
342
  if len(processed_data.unique()) == 1:
@@ -392,7 +383,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
392
383
  elif self.label_type == NUMERICAL:
393
384
  y = pd.to_numeric(y).to_numpy()
394
385
  self._label_scaler.fit(np.expand_dims(y, axis=-1))
395
- elif self.label_type == ROIS or self.label_type == SEMANTIC_SEGMENTATION_GT:
386
+ elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
396
387
  pass # Do nothing. TODO: Shall we call fit here?
397
388
  elif self.label_type == NER_ANNOTATION:
398
389
  # If there are ner annotations and text columns but no NER feature columns,
@@ -426,6 +417,24 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
426
417
  if y is not None:
427
418
  self._fit_y(y=y, X=X)
428
419
 
420
@staticmethod
def convert_categorical_to_text(col_value: pd.Series, template: str, col_name: str):
    """
    Render one categorical column as text using the requested template.

    Null entries always become the empty string. Non-null entries are
    formatted according to `template`:

    - "direct": the value itself, e.g. "red"
    - "list":   "<col_name>: <value>"
    - "text":   "<col_name> is <value>"
    - "latex":  "<value> & "

    Parameters
    ----------
    col_value
        The column's values.
    template
        Name of the conversion template; one of the four listed above.
    col_name
        The column name, used by the "list" and "text" templates.

    Returns
    -------
    A pd.Series of strings.

    Raises
    ------
    ValueError
        If `template` is not one of the supported names.
    """
    # TODO: do we need to consider whether categorical values are valid text?
    as_objects = col_value.astype("object")

    # Formatters stay lazy (lambdas) so `col_name` is only concatenated when a
    # non-null element is actually rendered.
    formatters = (
        ("direct", lambda ele: str(ele)),
        ("list", lambda ele: col_name + ": " + str(ele)),
        ("text", lambda ele: col_name + " is " + str(ele)),
        ("latex", lambda ele: str(ele) + " & "),
    )
    for known_template, render in formatters:
        if template == known_template:
            return as_objects.apply(lambda ele: "" if pd.isnull(ele) else render(ele))

    raise ValueError(
        f"Unsupported template {template} for converting categorical data into text. Select one from: ['direct', 'list', 'text', 'latex']."
    )
437
+
429
438
  def transform_text(
430
439
  self,
431
440
  df: pd.DataFrame,
@@ -455,10 +464,15 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
455
464
  for col_name in self._text_feature_names:
456
465
  col_value = df[col_name]
457
466
  col_type = self._column_types[col_name]
458
- if col_type == TEXT or col_type == CATEGORICAL:
459
- # TODO: do we need to consider whether categorical values are valid text?
467
+ if col_type == TEXT:
460
468
  col_value = col_value.astype("object")
461
469
  processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele))
470
+ elif col_type == CATEGORICAL:
471
+ processed_data = self.convert_categorical_to_text(
472
+ col_value=col_value,
473
+ template=self._config.categorical.convert_to_text_template,
474
+ col_name=col_name,
475
+ )
462
476
  elif col_type == NUMERICAL:
463
477
  processed_data = pd.to_numeric(col_value).apply("{:.3f}".format)
464
478
  elif col_type == f"{TEXT}_{IDENTIFIER}":
@@ -710,7 +724,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
710
724
  self._fit_called or self._fit_x_called
711
725
  ), "You will need to first call preprocessor.fit before calling preprocessor.transform_categorical."
712
726
  categorical_features = {}
713
- for col_name, num_category in zip(self._categorical_feature_names, self._categorical_num_categories):
727
+ for col_name, num_category in self._categorical_num_categories.items():
714
728
  col_value = df[col_name]
715
729
  processed_data = col_value.astype("category")
716
730
  generator = self._feature_generators[col_name]
@@ -757,7 +771,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
757
771
  elif self.label_type == NUMERICAL:
758
772
  y = pd.to_numeric(y_df).to_numpy()
759
773
  y = self._label_scaler.transform(np.expand_dims(y, axis=-1))[:, 0].astype(np.float32)
760
- elif self.label_type == ROIS or self.label_type == SEMANTIC_SEGMENTATION_GT:
774
+ elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
761
775
  y = y_df.to_list()
762
776
  elif self.label_type == NER_ANNOTATION:
763
777
  y = self._label_generator.transform(y_df)
@@ -866,8 +880,11 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
866
880
  ), "You will need to first call preprocessor.fit_y() before calling preprocessor.transform_prediction."
867
881
 
868
882
  if self.label_type == CATEGORICAL:
869
- assert y_pred.shape[1] >= 2
870
- y_pred = y_pred.argmax(axis=1)
883
+ assert len(y_pred.shape) <= 2
884
+ if len(y_pred.shape) == 2 and y_pred.shape[1] >= 2:
885
+ y_pred = y_pred.argmax(axis=1)
886
+ else:
887
+ y_pred = (y_pred > 0.5).astype(int)
871
888
  # Transform the predicted label back to the original space (e.g., string values)
872
889
  if inverse_categorical:
873
890
  y_pred = self._label_generator.inverse_transform(y_pred)
@@ -1,11 +1,14 @@
1
+ import logging
2
+ import random
1
3
  from typing import Any, Dict, List, Optional, Union
2
4
 
3
- import numpy as np
4
5
  from torch import nn
5
6
 
6
7
  from ..constants import CATEGORICAL, COLUMN
7
8
  from .collator import StackCollator, TupleCollator
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  class CategoricalProcessor:
11
14
  """
@@ -18,6 +21,7 @@ class CategoricalProcessor:
18
21
  self,
19
22
  model: nn.Module,
20
23
  requires_column_info: bool = False,
24
+ dropout: Optional[float] = 0,
21
25
  ):
22
26
  """
23
27
  Parameters
@@ -27,8 +31,16 @@ class CategoricalProcessor:
27
31
  requires_column_info
28
32
  Whether to require feature column information in dataloader.
29
33
  """
34
+ logger.debug(f"initializing categorical processor for model {model.prefix}")
30
35
  self.prefix = model.prefix
31
36
  self.requires_column_info = requires_column_info
37
+ self.num_categories = model.num_categories
38
+ self.dropout = dropout
39
+ assert 0 <= self.dropout <= 1
40
+ if self.dropout > 0:
41
+ logger.debug(f"categorical value dropout probability: {self.dropout}")
42
+ fill_values = {k: v - 1 for k, v in self.num_categories.items()}
43
+ logger.debug(f"dropped values will be replaced by {fill_values}")
32
44
 
33
45
  @property
34
46
  def categorical_key(self):
@@ -60,6 +72,7 @@ class CategoricalProcessor:
60
72
  def process_one_sample(
61
73
  self,
62
74
  categorical_features: Dict[str, int],
75
+ is_training: bool,
63
76
  ) -> Dict:
64
77
  """
65
78
  Process one sample's categorical features. Assume the categorical features
@@ -69,6 +82,8 @@ class CategoricalProcessor:
69
82
  ----------
70
83
  categorical_features
71
84
  Categorical features of one sample.
85
+ is_training
86
+ Whether to do processing in the training mode.
72
87
 
73
88
  Returns
74
89
  -------
@@ -80,6 +95,17 @@ class CategoricalProcessor:
80
95
  for i, col_name in enumerate(categorical_features.keys()):
81
96
  ret[f"{self.categorical_column_prefix}_{col_name}"] = i
82
97
 
98
+ if is_training and self.dropout > 0:
99
+ categorical_features_copy = dict()
100
+ for k, v in categorical_features.items():
101
+ if random.uniform(0, 1) <= self.dropout:
102
+ categorical_features_copy[k] = self.num_categories[k] - 1
103
+ else:
104
+ categorical_features_copy[k] = v
105
+ categorical_features = categorical_features_copy
106
+
107
+ # make sure keys are in the same order
108
+ assert list(categorical_features.keys()) == list(self.num_categories.keys())
83
109
  ret[self.categorical_key] = list(categorical_features.values())
84
110
 
85
111
  return ret
@@ -87,7 +113,7 @@ class CategoricalProcessor:
87
113
  def __call__(
88
114
  self,
89
115
  categorical_features: Dict[str, int],
90
- feature_modalities: Dict[str, Union[int, float, list]],
116
+ sub_dtypes: Dict[str, str],
91
117
  is_training: bool,
92
118
  ) -> Dict:
93
119
  """
@@ -97,13 +123,16 @@ class CategoricalProcessor:
97
123
  ----------
98
124
  categorical_features
99
125
  Categorical features of one sample.
100
- feature_modalities
101
- The modality of the feature columns.
126
+ sub_dtypes
127
+ The sub data types of all categorical columns.
102
128
  is_training
103
- Whether to do processing in the training mode. This unused flag is for the API compatibility.
129
+ Whether to do processing in the training mode.
104
130
 
105
131
  Returns
106
132
  -------
107
133
  A dictionary containing one sample's processed categorical features.
108
134
  """
109
- return self.process_one_sample(categorical_features)
135
+ return self.process_one_sample(
136
+ categorical_features=categorical_features,
137
+ is_training=is_training,
138
+ )