autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -1,16 +1,18 @@
1
1
  import logging
2
2
  import os
3
+ import re
3
4
  import warnings
4
5
  from typing import Any, Dict, List, Optional, Union
5
6
 
6
7
  import numpy as np
7
8
  from numpy.typing import NDArray
8
9
  from omegaconf import DictConfig
10
+ from tokenizers import pre_tokenizers
9
11
  from torch import nn
10
12
 
11
- from ..constants import AUTOMM, NER_ANNOTATION, NER_TEXT, TEXT, TEXT_NER
13
+ from ..constants import NER_ANNOTATION, NER_TEXT, TEXT, TEXT_NER
14
+ from ..models.utils import get_pretrained_tokenizer
12
15
  from .collator import PadCollator, StackCollator
13
- from .utils import process_ner_annotations, tokenize_ner_text
14
16
 
15
17
  logger = logging.getLogger(__name__)
16
18
 
@@ -124,12 +126,12 @@ class NerProcessor:
124
126
  ner_text = all_features[text_column]
125
127
  if is_training or annotation_column is not None:
126
128
  ner_annotation = all_features[annotation_column]
127
- label, col_tokens, token_to_word_mappings, word_offsets = process_ner_annotations(
129
+ label, col_tokens, token_to_word_mappings, word_offsets = self.process_ner_annotations(
128
130
  ner_annotation, ner_text, self.entity_map, self.tokenizer
129
131
  )
130
132
  ret.update({self.label_key: label})
131
133
  else:
132
- col_tokens, token_to_word_mappings, word_offsets = tokenize_ner_text(ner_text, self.tokenizer)
134
+ col_tokens, token_to_word_mappings, word_offsets = self.tokenize_ner_text(ner_text, self.tokenizer)
133
135
  ret.update({self.label_key: np.array([], dtype=np.int32)})
134
136
 
135
137
  ret.update(
@@ -144,6 +146,192 @@ class NerProcessor:
144
146
 
145
147
  return ret
146
148
 
149
+ @classmethod
150
+ def process_ner_annotations(cls, ner_annotations, ner_text, entity_map, tokenizer, is_eval=False):
151
+ """
152
+ Generate token-level/word-level labels with given text and NER annotations.
153
+
154
+ Parameters
155
+ ----------
156
+ ner_annotations
157
+ The NER annotations.
158
+ ner_text
159
+ The corresponding raw text.
160
+ entity_map
161
+ The map between tags and tag indexes. e.g., {"PER":2, "LOC":3}.
162
+ tokenizer
163
+ The tokenizer to be used.
164
+ is_eval
165
+ Whether it is for evaluation or not, default: False
166
+
167
+ Returns
168
+ -------
169
+ Token-level/word-level labels and text features.
170
+ """
171
+ col_tokens, token_to_word_mappings, word_offsets = cls.tokenize_ner_text(ner_text, tokenizer)
172
+ num_words = len(set(token_to_word_mappings)) - 1
173
+ word_label = [1] * num_words
174
+ # TODO: Potentially optimize word label generation via binary search
175
+ b_prefix = "B-"
176
+ i_prefix = "I-"
177
+ for annot in ner_annotations:
178
+ custom_offset = annot[0]
179
+ custom_label = annot[1]
180
+ is_start_word = True
181
+ for idx, word_offset in enumerate(word_offsets[:num_words, :]):
182
+ # support multiple words in an annotated offset range.
183
+ # Allow partial overlapping between custom annotations and pretokenized words.
184
+ if (word_offset[0] < custom_offset[1]) and (custom_offset[0] < word_offset[1]):
185
+ if not (
186
+ re.match(b_prefix, custom_label, re.IGNORECASE)
187
+ or re.match(i_prefix, custom_label, re.IGNORECASE)
188
+ ):
189
+ if is_start_word and b_prefix + custom_label in entity_map:
190
+ word_label[idx] = entity_map[b_prefix + custom_label]
191
+ is_start_word = False
192
+ elif i_prefix + custom_label in entity_map:
193
+ word_label[idx] = entity_map[i_prefix + custom_label]
194
+ else:
195
+ if custom_label in entity_map:
196
+ word_label[idx] = entity_map[custom_label]
197
+
198
+ token_label = [0] * len(col_tokens.input_ids)
199
+ temp = set()
200
+ counter = 0
201
+ for idx, token_to_word in enumerate(token_to_word_mappings):
202
+ if token_to_word != -1 and token_to_word not in temp:
203
+ temp.add(token_to_word)
204
+ token_label[idx] = word_label[counter]
205
+ counter += 1
206
+ if not is_eval:
207
+ label = token_label # return token-level labels for training
208
+ else:
209
+ label = word_label # return word-level labels for evaluation
210
+
211
+ return label, col_tokens, token_to_word_mappings, word_offsets
212
+
213
+ @classmethod
214
+ def tokenize_ner_text(cls, text, tokenizer):
215
+ """
216
+ Tokenization process for the NER task. It will be used for the token-level label generation
217
+ and the input text tokenization.
218
+
219
+ Parameters
220
+ ----------
221
+ text
222
+ The raw text data.
223
+ tokenizer
224
+ The tokenizer to be used.
225
+
226
+ Returns
227
+ -------
228
+ The output of tokenizer and word offsets.
229
+ """
230
+ # pre-tokenization is required for NER token-level label generation.
231
+ words_with_offsets = pre_tokenizers.BertPreTokenizer().pre_tokenize_str(text)
232
+ words_with_offsets = (
233
+ cls.is_space_counted(words_with_offsets) if len(words_with_offsets) > 1 else words_with_offsets
234
+ )
235
+ words = [word for word, offset in words_with_offsets]
236
+ word_offsets = np.array([[offset[0], offset[1]] for word, offset in words_with_offsets], dtype=np.int32)
237
+ col_tokens = tokenizer(
238
+ words,
239
+ is_split_into_words=True,
240
+ return_offsets_mapping=True,
241
+ padding="max_length",
242
+ truncation=True,
243
+ max_length=tokenizer.model_max_length,
244
+ return_token_type_ids=True,
245
+ )
246
+ offset_mapping = np.array(col_tokens.offset_mapping, dtype=np.int32)
247
+ if len(words_with_offsets) > 1:
248
+ if offset_mapping.shape[0] > len(words):
249
+ word_offsets = np.pad(word_offsets, ((0, offset_mapping.shape[0] - len(words)), (0, 0)), "constant")
250
+ # token to word mappings: it will tell us which token belongs to which word.
251
+ token_to_word_mappings = [i if i != None else -1 for i in col_tokens.word_ids()]
252
+ if len(set(token_to_word_mappings)) != len(words) + 1:
253
+ warnings.warn(f"The token to word mappings are incorrect!")
254
+ else:
255
+ # If pre_tokenizer does not give word offsets, use word_ids and offset_mappings instead.
256
+ word_offsets = np.append(offset_mapping[1:], [[0, 0]], axis=0)
257
+ word_idx = np.arange(len(col_tokens.word_ids()) - col_tokens.word_ids().count(None))
258
+ token_to_word_mappings = [
259
+ val + word_idx[idx - 1] if val != None else -1 for idx, val in enumerate(col_tokens.word_ids())
260
+ ]
261
+
262
+ return col_tokens, token_to_word_mappings, word_offsets
263
+
264
+ @staticmethod
265
+ def is_space_counted(words_with_offsets):
266
+ """
267
+ Some tokenizers will count space into words for example.
268
+ Given text: 'hello world', normal bert will output: [('hello', (0, 5)), ('world', (6, 11))]
269
+ while some checkpoint will output: [('▁hello', (0, 5)), ('▁world', (5, 11))]
270
+ This will lead to inconsistency issue during labelling, details can be found here:
271
+ https://github.com/huggingface/transformers/issues/18111
272
+
273
+ This function will check whether space is counted or not and realign the offset.
274
+ """
275
+ offset0, offset1 = [], []
276
+ for word, offset in words_with_offsets:
277
+ offset0.append(offset[0])
278
+ offset1.append(offset[1])
279
+
280
+ realign = []
281
+ if offset0[1:] == offset1[:-1]: # space are counted
282
+ realign = [words_with_offsets[0]]
283
+ for word, offset in words_with_offsets[1:]:
284
+ if word.startswith("▁"): # it is "Lower One Eighth Block" (U+2581) rather than lower line (U+005F).
285
+ realign.append((word, (offset[0] + 1, offset[1])))
286
+ else:
287
+ realign.append((word, offset))
288
+
289
+ if realign:
290
+ return realign
291
+ else:
292
+ return words_with_offsets
293
+
294
+ def save_tokenizer(
295
+ self,
296
+ path: str,
297
+ ):
298
+ """
299
+ Save the text tokenizer and record its relative paths, e.g, hf_text.
300
+
301
+ Parameters
302
+ ----------
303
+ path
304
+ The root path of saving.
305
+
306
+ """
307
+ save_path = os.path.join(path, self.prefix)
308
+ self.tokenizer.save_pretrained(save_path)
309
+ self.tokenizer = self.prefix
310
+
311
+ def load_tokenizer(
312
+ self,
313
+ path: str,
314
+ ):
315
+ """
316
+ Load saved text tokenizers. If text/ner processors already have tokenizers,
317
+ then do nothing.
318
+
319
+ Parameters
320
+ ----------
321
+ path
322
+ The root path of loading.
323
+
324
+ Returns
325
+ -------
326
+ A list of text/ner processors with tokenizers loaded.
327
+ """
328
+ if isinstance(self.tokenizer, str):
329
+ load_path = os.path.join(path, self.tokenizer)
330
+ self.tokenizer = get_pretrained_tokenizer(
331
+ tokenizer_name=self.tokenizer_name,
332
+ checkpoint_name=load_path,
333
+ )
334
+
147
335
  def __call__(
148
336
  self,
149
337
  all_features: Dict[str, Union[NDArray, list]],
@@ -1,3 +1,5 @@
1
+ import logging
2
+ import random
1
3
  from typing import Any, Dict, List, Optional, Union
2
4
 
3
5
  import numpy as np
@@ -6,6 +8,8 @@ from torch import nn
6
8
  from ..constants import COLUMN, NUMERICAL
7
9
  from .collator import StackCollator
8
10
 
11
+ logger = logging.getLogger(__name__)
12
+
9
13
 
10
14
  class NumericalProcessor:
11
15
  """
@@ -19,6 +23,7 @@ class NumericalProcessor:
19
23
  model: nn.Module,
20
24
  merge: Optional[str] = "concat",
21
25
  requires_column_info: bool = False,
26
+ dropout: Optional[float] = 0,
22
27
  ):
23
28
  """
24
29
  Parameters
@@ -33,9 +38,16 @@ class NumericalProcessor:
33
38
  requires_column_info
34
39
  Whether to require feature column information in dataloader.
35
40
  """
41
+ logger.debug(f"initializing numerical processor for model {model.prefix}")
36
42
  self.prefix = model.prefix
37
43
  self.merge = merge
38
44
  self.requires_column_info = requires_column_info
45
+ self.numerical_fill_values = model.numerical_fill_values
46
+ self.dropout = dropout
47
+ assert 0 <= self.dropout <= 1
48
+ if self.dropout > 0:
49
+ logger.debug(f"numerical value dropout probability: {self.dropout}")
50
+ logger.debug(f"dropped values will be replaced by {self.numerical_fill_values}")
39
51
 
40
52
  @property
41
53
  def numerical_key(self):
@@ -67,6 +79,7 @@ class NumericalProcessor:
67
79
  def process_one_sample(
68
80
  self,
69
81
  numerical_features: Dict[str, float],
82
+ is_training: bool,
70
83
  ) -> Dict:
71
84
  """
72
85
  Process one sample's numerical features.
@@ -76,6 +89,8 @@ class NumericalProcessor:
76
89
  ----------
77
90
  numerical_features
78
91
  Numerical features of one sample.
92
+ is_training
93
+ Whether to do processing in the training mode.
79
94
 
80
95
  Returns
81
96
  -------
@@ -87,6 +102,15 @@ class NumericalProcessor:
87
102
  for i, col_name in enumerate(numerical_features.keys()):
88
103
  ret[f"{self.numerical_column_prefix}_{col_name}"] = i
89
104
 
105
+ if is_training and self.dropout > 0:
106
+ numerical_features_copy = dict()
107
+ for k, v in numerical_features.items():
108
+ if random.uniform(0, 1) <= self.dropout:
109
+ numerical_features_copy[k] = self.numerical_fill_values[k]
110
+ else:
111
+ numerical_features_copy[k] = v
112
+ numerical_features = numerical_features_copy
113
+
90
114
  if self.merge == "concat":
91
115
  ret[self.numerical_key] = np.array(list(numerical_features.values()), dtype=np.float32)
92
116
  else:
@@ -97,7 +121,7 @@ class NumericalProcessor:
97
121
  def __call__(
98
122
  self,
99
123
  numerical_features: Dict[str, float],
100
- feature_modalities: Dict[str, Union[int, float, list]],
124
+ sub_dtypes: Dict[str, str],
101
125
  is_training: bool,
102
126
  ) -> Dict:
103
127
  """
@@ -107,13 +131,16 @@ class NumericalProcessor:
107
131
  ----------
108
132
  numerical_features
109
133
  Numerical features of one sample.
110
- feature_modalities
111
- The modality of the feature columns.
134
+ sub_dtypes
135
+ The sub data types of all numerical columns.
112
136
  is_training
113
- Whether to do processing in the training mode. This unused flag is for the API compatibility.
137
+ Whether to do processing in the training mode.
114
138
 
115
139
  Returns
116
140
  -------
117
141
  A dictionary containing one sample's processed numerical features.
118
142
  """
119
- return self.process_one_sample(numerical_features)
143
+ return self.process_one_sample(
144
+ numerical_features=numerical_features,
145
+ is_training=is_training,
146
+ )
@@ -1,44 +1,32 @@
1
1
  import logging
2
2
  import random
3
- import warnings
4
- from io import BytesIO
5
3
  from typing import Dict, List, Optional, Union
6
4
 
7
5
  import numpy as np
8
6
  import PIL
9
7
  import torch
10
- from omegaconf import DictConfig
11
8
  from PIL import Image, ImageFile
12
9
  from torch import nn
13
10
  from torchvision import transforms
14
11
 
15
- from .utils import construct_image_processor, image_mean_std
16
-
17
- try:
18
- from torchvision.transforms import InterpolationMode
19
-
20
- BICUBIC = InterpolationMode.BICUBIC
21
- except ImportError:
22
- BICUBIC = PIL.Image.BICUBIC
23
-
24
12
  from ..constants import (
25
13
  CLASS_LABEL,
26
14
  COLUMN,
27
15
  IMAGE,
28
- IMAGE_BYTEARRAY,
29
16
  IMAGE_VALID_NUM,
30
17
  LABEL,
31
18
  MASK_LABEL,
32
19
  SEMANTIC_SEGMENTATION_GT,
33
20
  SEMANTIC_SEGMENTATION_IMG,
34
21
  )
35
- from .collator import ListCollator, PadCollator, StackCollator
22
+ from .collator import ListCollator, PadCollator
23
+ from .process_image import ImageProcessor
36
24
 
37
25
  logger = logging.getLogger(__name__)
38
26
  ImageFile.LOAD_TRUNCATED_IMAGES = True
39
27
 
40
28
 
41
- class SemanticSegImageProcessor:
29
+ class SemanticSegImageProcessor(ImageProcessor):
42
30
  """
43
31
  Prepare image data for the model specified by "prefix". For multiple models requiring image data,
44
32
  we need to create a ImageProcessor for each related model so that they will have independent input.
@@ -51,7 +39,6 @@ class SemanticSegImageProcessor:
51
39
  gt_transforms: List[str],
52
40
  train_transforms: Optional[List[str]] = None,
53
41
  val_transforms: Optional[List[str]] = None,
54
- norm_type: Optional[str] = None,
55
42
  max_img_num_per_col: Optional[int] = 1,
56
43
  missing_value_strategy: Optional[str] = "skip",
57
44
  requires_column_info: bool = False,
@@ -70,15 +57,6 @@ class SemanticSegImageProcessor:
70
57
  A list of image transforms used in training for data augmentation. Note that the transform order matters.
71
58
  val_transforms
72
59
  A list of image transforms used in validation/test/prediction. Note that the transform order matters.
73
- norm_type
74
- How to normalize an image. We now support:
75
- - inception
76
- Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
77
- - imagenet
78
- Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
79
- - clip
80
- Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
81
- std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
82
60
  max_img_num_per_col
83
61
  The maximum number of images one sample can have.
84
62
  missing_value_strategy
@@ -98,7 +76,8 @@ class SemanticSegImageProcessor:
98
76
  self.requires_column_info = requires_column_info
99
77
 
100
78
  self.size = model.image_size
101
- self.mean, self.std = image_mean_std(norm_type)
79
+ self.mean = model.image_mean
80
+ self.std = model.image_std
102
81
  self.normalization = transforms.Normalize(self.mean, self.std)
103
82
  self.num_classes = model.num_classes
104
83
  self.ignore_label = ignore_label
@@ -110,10 +89,10 @@ class SemanticSegImageProcessor:
110
89
  self.max_img_num_per_col = max_img_num_per_col
111
90
  logger.debug(f"max_img_num_per_col: {max_img_num_per_col}")
112
91
 
113
- self.img_processor = construct_image_processor(
92
+ self.img_processor = self.construct_image_processor(
114
93
  image_transforms=self.img_transforms, size=self.size, normalization=self.normalization
115
94
  )
116
- self.gt_processor = construct_image_processor(
95
+ self.gt_processor = self.construct_image_processor(
117
96
  image_transforms=self.gt_transforms, size=self.size, normalization=None
118
97
  )
119
98
  self.train_transforms = self.get_train_transforms(train_transforms)
@@ -325,3 +304,19 @@ class SemanticSegImageProcessor:
325
304
  if trans_mode == "random_horizontal_flip":
326
305
  train_trans.append(transforms.RandomHorizontalFlip(1.0))
327
306
  return transforms.Compose(train_trans)
307
+
308
+ def __getstate__(self):
309
+ odict = self.__dict__.copy() # get attribute dictionary
310
+ del odict["img_processor"]
311
+ del odict["gt_processor"]
312
+
313
+ return odict
314
+
315
+ def __setstate__(self, state):
316
+ self.__dict__ = state
317
+ self.img_processor = self.construct_image_processor(
318
+ image_transforms=self.img_transforms, size=self.size, normalization=self.normalization
319
+ )
320
+ self.gt_processor = self.construct_image_processor(
321
+ image_transforms=self.gt_transforms, size=self.size, normalization=None
322
+ )