autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -1,25 +1,23 @@
|
|
1
1
|
import ast
|
2
|
+
import codecs
|
2
3
|
import logging
|
3
4
|
import os
|
5
|
+
import random
|
4
6
|
import warnings
|
5
7
|
from copy import deepcopy
|
6
|
-
from typing import
|
8
|
+
from typing import Dict, List, Optional, Tuple, Union
|
7
9
|
|
8
10
|
import numpy as np
|
9
11
|
from numpy.typing import NDArray
|
10
12
|
from omegaconf import DictConfig
|
13
|
+
from text_unidecode import unidecode
|
11
14
|
from torch import nn
|
12
15
|
|
13
16
|
from ..constants import CHOICES_IDS, COLUMN, TEXT, TEXT_SEGMENT_IDS, TEXT_TOKEN_IDS, TEXT_VALID_LENGTH
|
17
|
+
from ..models.utils import get_pretrained_tokenizer
|
14
18
|
from .collator import PadCollator, StackCollator
|
15
19
|
from .template_engine import TemplateEngine
|
16
20
|
from .trivial_augmenter import TrivialAugment
|
17
|
-
from .utils import (
|
18
|
-
extract_value_from_config,
|
19
|
-
get_text_token_max_len,
|
20
|
-
normalize_txt,
|
21
|
-
register_encoding_decoding_error_handlers,
|
22
|
-
)
|
23
21
|
|
24
22
|
logger = logging.getLogger(__name__)
|
25
23
|
|
@@ -36,9 +34,7 @@ class TextProcessor:
|
|
36
34
|
def __init__(
|
37
35
|
self,
|
38
36
|
model: nn.Module,
|
39
|
-
max_len: Optional[int] = None,
|
40
37
|
insert_sep: Optional[bool] = True,
|
41
|
-
text_segment_num: Optional[int] = 1,
|
42
38
|
stochastic_chunk: Optional[bool] = False,
|
43
39
|
requires_column_info: bool = False,
|
44
40
|
text_detection_length: Optional[int] = None,
|
@@ -46,18 +42,15 @@ class TextProcessor:
|
|
46
42
|
train_augment_types: Optional[List[str]] = None,
|
47
43
|
template_config: Optional[DictConfig] = None,
|
48
44
|
normalize_text: Optional[bool] = False,
|
45
|
+
dropout: Optional[float] = 0,
|
49
46
|
):
|
50
47
|
"""
|
51
48
|
Parameters
|
52
49
|
----------
|
53
50
|
model
|
54
51
|
The model for which this processor would be created.
|
55
|
-
max_len
|
56
|
-
The maximum length of text tokens.
|
57
52
|
insert_sep
|
58
53
|
Whether to insert SEP tokens.
|
59
|
-
text_segment_num
|
60
|
-
The number of text segments.
|
61
54
|
stochastic_chunk
|
62
55
|
Whether to use stochastic chunking, which will randomly slice each individual text.
|
63
56
|
requires_column_info
|
@@ -75,6 +68,7 @@ class TextProcessor:
|
|
75
68
|
Examples of normalized texts can be found at
|
76
69
|
https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize#15-a-few-examples-of-normalized-texts
|
77
70
|
"""
|
71
|
+
logger.debug(f"initializing text processor for model {model.prefix}")
|
78
72
|
self.prefix = model.prefix
|
79
73
|
self.requires_column_info = requires_column_info
|
80
74
|
self.tokenizer_name = model.tokenizer_name
|
@@ -86,38 +80,17 @@ class TextProcessor:
|
|
86
80
|
self.tokenizer.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
87
81
|
|
88
82
|
self.cls_token_id, self.sep_token_id, self.eos_token_id = self.get_special_tokens(tokenizer=self.tokenizer)
|
89
|
-
self.max_len =
|
90
|
-
provided_max_len=max_len,
|
91
|
-
config=model.config,
|
92
|
-
tokenizer=self.tokenizer,
|
93
|
-
checkpoint_name=model.checkpoint_name,
|
94
|
-
)
|
95
|
-
logger.debug(f"text max length: {self.max_len}")
|
83
|
+
self.max_len = model.max_text_len
|
96
84
|
self.insert_sep = insert_sep
|
97
85
|
self.eos_only = self.cls_token_id == self.sep_token_id == self.eos_token_id
|
98
|
-
|
99
|
-
extracted = extract_value_from_config(config=model.config.to_diff_dict(), keys=("type_vocab_size",))
|
100
|
-
if len(extracted) == 0:
|
101
|
-
default_segment_num = 1
|
102
|
-
elif len(extracted) == 1:
|
103
|
-
default_segment_num = extracted[0]
|
104
|
-
else:
|
105
|
-
raise ValueError(f" more than one type_vocab_size values are detected: {extracted}")
|
106
|
-
|
107
|
-
if default_segment_num <= 0:
|
108
|
-
default_segment_num = 1
|
109
|
-
|
110
|
-
if text_segment_num < default_segment_num:
|
111
|
-
warnings.warn(
|
112
|
-
f"provided text_segment_num: {text_segment_num} "
|
113
|
-
f"is smaller than {model.checkpoint_name}'s default: {default_segment_num}"
|
114
|
-
)
|
115
|
-
self.text_segment_num = min(text_segment_num, default_segment_num)
|
116
|
-
assert self.text_segment_num >= 1
|
117
|
-
logger.debug(f"text segment num: {self.text_segment_num}")
|
86
|
+
self.text_segment_num = model.text_segment_num
|
118
87
|
|
119
88
|
self.stochastic_chunk = stochastic_chunk
|
120
89
|
self.normalize_text = normalize_text
|
90
|
+
assert 0 <= dropout <= 1
|
91
|
+
if dropout > 0:
|
92
|
+
logger.debug(f"text dropout probability: {dropout}")
|
93
|
+
self.dropout = dropout
|
121
94
|
|
122
95
|
# construct augmentor
|
123
96
|
self.train_augment_types = train_augment_types
|
@@ -131,7 +104,7 @@ class TextProcessor:
|
|
131
104
|
self.template_engine = None
|
132
105
|
|
133
106
|
if self.normalize_text:
|
134
|
-
register_encoding_decoding_error_handlers()
|
107
|
+
self.register_encoding_decoding_error_handlers()
|
135
108
|
|
136
109
|
@property
|
137
110
|
def text_token_ids_key(self):
|
@@ -243,14 +216,9 @@ class TextProcessor:
|
|
243
216
|
segment_ids.append(seg)
|
244
217
|
seg = (seg + 1) % self.text_segment_num
|
245
218
|
|
246
|
-
if
|
247
|
-
|
248
|
-
|
249
|
-
segment_ids.append(seg)
|
250
|
-
else: # backward compatibility
|
251
|
-
if token_ids[-1] != self.sep_token_id:
|
252
|
-
token_ids.append(self.sep_token_id)
|
253
|
-
segment_ids.append(seg)
|
219
|
+
if token_ids[-1] != self.eos_token_id:
|
220
|
+
token_ids.append(self.eos_token_id)
|
221
|
+
segment_ids.append(seg)
|
254
222
|
|
255
223
|
ret.update(
|
256
224
|
{
|
@@ -298,7 +266,9 @@ class TextProcessor:
|
|
298
266
|
|
299
267
|
for col_name, col_text in text.items():
|
300
268
|
if is_training:
|
301
|
-
if self.
|
269
|
+
if self.dropout > 0 and random.uniform(0, 1) <= self.dropout:
|
270
|
+
col_text = ""
|
271
|
+
elif self.train_augmenter is not None:
|
302
272
|
# naive way to detect categorical/numerical text:
|
303
273
|
if len(col_text.split(" ")) >= self.text_detection_length:
|
304
274
|
col_text = self.train_augmenter(col_text)
|
@@ -446,8 +416,8 @@ class TextProcessor:
|
|
446
416
|
|
447
417
|
def __call__(
|
448
418
|
self,
|
449
|
-
|
450
|
-
|
419
|
+
text: Dict[str, str],
|
420
|
+
sub_dtypes: Dict[str, str],
|
451
421
|
is_training: bool,
|
452
422
|
) -> Dict:
|
453
423
|
"""
|
@@ -455,10 +425,10 @@ class TextProcessor:
|
|
455
425
|
|
456
426
|
Parameters
|
457
427
|
----------
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
The
|
428
|
+
text
|
429
|
+
Text of one sample.
|
430
|
+
sub_dtypes
|
431
|
+
The sub data types of all text columns.
|
462
432
|
is_training
|
463
433
|
Whether to do processing in the training mode.
|
464
434
|
|
@@ -467,9 +437,9 @@ class TextProcessor:
|
|
467
437
|
A dictionary containing one sample's text tokens, valid length, and segment ids.
|
468
438
|
"""
|
469
439
|
if self.normalize_text:
|
470
|
-
|
440
|
+
text = {col_name: self.normalize_txt(col_text) for col_name, col_text in text.items()}
|
471
441
|
|
472
|
-
return self.build_one_token_sequence_from_text(
|
442
|
+
return self.build_one_token_sequence_from_text(text=text, is_training=is_training)
|
473
443
|
|
474
444
|
def __deepcopy__(self, memo):
|
475
445
|
cls = self.__class__
|
@@ -495,3 +465,70 @@ class TextProcessor:
|
|
495
465
|
self.train_augmenter = self.construct_text_augmenter(
|
496
466
|
state["text_trivial_aug_maxscale"], state["train_augment_types"]
|
497
467
|
)
|
468
|
+
|
469
|
+
def save_tokenizer(
|
470
|
+
self,
|
471
|
+
path: str,
|
472
|
+
):
|
473
|
+
"""
|
474
|
+
Save the text tokenizer and record its relative paths, e.g, hf_text.
|
475
|
+
|
476
|
+
Parameters
|
477
|
+
----------
|
478
|
+
path
|
479
|
+
The root path of saving.
|
480
|
+
|
481
|
+
"""
|
482
|
+
save_path = os.path.join(path, self.prefix)
|
483
|
+
self.tokenizer.save_pretrained(save_path)
|
484
|
+
self.tokenizer = self.prefix
|
485
|
+
|
486
|
+
def load_tokenizer(
|
487
|
+
self,
|
488
|
+
path: str,
|
489
|
+
):
|
490
|
+
"""
|
491
|
+
Load saved text tokenizers. If text/ner processors already have tokenizers,
|
492
|
+
then do nothing.
|
493
|
+
|
494
|
+
Parameters
|
495
|
+
----------
|
496
|
+
path
|
497
|
+
The root path of loading.
|
498
|
+
|
499
|
+
Returns
|
500
|
+
-------
|
501
|
+
A list of text/ner processors with tokenizers loaded.
|
502
|
+
"""
|
503
|
+
if isinstance(self.tokenizer, str):
|
504
|
+
load_path = os.path.join(path, self.tokenizer)
|
505
|
+
self.tokenizer = get_pretrained_tokenizer(
|
506
|
+
tokenizer_name=self.tokenizer_name,
|
507
|
+
checkpoint_name=load_path,
|
508
|
+
)
|
509
|
+
|
510
|
+
@staticmethod
|
511
|
+
def normalize_txt(text: str) -> str:
|
512
|
+
"""Resolve the encoding problems and normalize the abnormal characters."""
|
513
|
+
|
514
|
+
text = (
|
515
|
+
text.encode("raw_unicode_escape")
|
516
|
+
.decode("utf-8", errors="replace_decoding_with_cp1252")
|
517
|
+
.encode("cp1252", errors="replace_encoding_with_utf8")
|
518
|
+
.decode("utf-8", errors="replace_decoding_with_cp1252")
|
519
|
+
)
|
520
|
+
text = unidecode(text)
|
521
|
+
return text
|
522
|
+
|
523
|
+
@staticmethod
|
524
|
+
def register_encoding_decoding_error_handlers() -> None:
|
525
|
+
"""Register the encoding and decoding error handlers for `utf-8` and `cp1252`."""
|
526
|
+
|
527
|
+
def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
|
528
|
+
return error.object[error.start : error.end].encode("utf-8"), error.end
|
529
|
+
|
530
|
+
def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
|
531
|
+
return error.object[error.start : error.end].decode("cp1252"), error.end
|
532
|
+
|
533
|
+
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
|
534
|
+
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)
|
@@ -1,11 +1,9 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
3
|
import numpy as np
|
4
|
-
from omegaconf import OmegaConf
|
4
|
+
from omegaconf import DictConfig, OmegaConf
|
5
5
|
|
6
|
-
from
|
7
|
-
|
8
|
-
from ..constants import AUTOMM
|
6
|
+
from .templates import DatasetTemplates, Template, TemplateCollection
|
9
7
|
|
10
8
|
logger = logging.getLogger(__name__)
|
11
9
|
|
@@ -15,7 +13,7 @@ class TemplateEngine:
|
|
15
13
|
Class to manage the selection and use of templates.
|
16
14
|
"""
|
17
15
|
|
18
|
-
def __init__(self, template_config:
|
16
|
+
def __init__(self, template_config: DictConfig):
|
19
17
|
"""
|
20
18
|
Initialize the TemplateEngine using preset templates from existing datasets or custom templates specified in config config.data.templates, if specified.
|
21
19
|
|
@@ -28,10 +26,10 @@ class TemplateEngine:
|
|
28
26
|
self.template_config = template_config
|
29
27
|
collection = TemplateCollection()
|
30
28
|
self.all_datasets = collection.keys
|
31
|
-
self.preset_templates =
|
32
|
-
self.custom_templates =
|
33
|
-
self.num_templates =
|
34
|
-
self.template_length =
|
29
|
+
self.preset_templates = self.template_config.preset_templates
|
30
|
+
self.custom_templates = self.template_config.custom_templates
|
31
|
+
self.num_templates = self.template_config.num_templates
|
32
|
+
self.template_length = self.template_config.template_length
|
35
33
|
|
36
34
|
if self.preset_templates:
|
37
35
|
assert (
|
@@ -10,7 +10,7 @@ import random
|
|
10
10
|
import nltk
|
11
11
|
from PIL import Image, ImageEnhance, ImageOps
|
12
12
|
|
13
|
-
from ..constants import
|
13
|
+
from ..constants import IMAGE, TEXT
|
14
14
|
|
15
15
|
logger = logging.getLogger(__name__)
|
16
16
|
|
@@ -290,7 +290,7 @@ class TrivialAugment:
|
|
290
290
|
# lazy import of nlpaug due to the speed issue. See more in https://github.com/autogluon/autogluon/issues/2706
|
291
291
|
import nlpaug.augmenter.word as naw
|
292
292
|
|
293
|
-
from
|
293
|
+
from .nlpaug import InsertPunctuation
|
294
294
|
|
295
295
|
if op == "syn_replacement":
|
296
296
|
op = naw.SynonymAug(aug_src="wordnet", aug_p=scale, aug_max=None)
|