magic-pdf 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. magic_pdf/data/batch_build_dataset.py +156 -0
  2. magic_pdf/data/dataset.py +44 -24
  3. magic_pdf/data/utils.py +108 -9
  4. magic_pdf/dict2md/ocr_mkcontent.py +4 -3
  5. magic_pdf/libs/pdf_image_tools.py +11 -6
  6. magic_pdf/libs/performance_stats.py +12 -1
  7. magic_pdf/libs/version.py +1 -1
  8. magic_pdf/model/batch_analyze.py +175 -201
  9. magic_pdf/model/doc_analyze_by_custom_model.py +137 -92
  10. magic_pdf/model/pdf_extract_kit.py +5 -38
  11. magic_pdf/model/sub_modules/language_detection/utils.py +2 -4
  12. magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py +24 -19
  13. magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +3 -1
  14. magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py +3 -1
  15. magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +31 -102
  16. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py +13 -0
  17. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py +189 -0
  18. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py +8 -0
  19. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py +163 -0
  20. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py +2351 -0
  21. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py +9 -0
  22. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py +132 -0
  23. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +132 -0
  24. magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +1084 -0
  25. magic_pdf/model/sub_modules/model_init.py +50 -37
  26. magic_pdf/model/sub_modules/model_utils.py +17 -11
  27. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/__init__.py +1 -0
  28. magic_pdf/model/sub_modules/ocr/{paddleocr → paddleocr2pytorch}/ocr_utils.py +102 -97
  29. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py +193 -0
  30. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py +39 -0
  31. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py +8 -0
  32. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py +48 -0
  33. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py +418 -0
  34. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py +25 -0
  35. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py +105 -0
  36. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py +62 -0
  37. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py +269 -0
  38. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py +290 -0
  39. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py +516 -0
  40. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py +136 -0
  41. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py +234 -0
  42. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py +638 -0
  43. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py +76 -0
  44. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py +43 -0
  45. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py +23 -0
  46. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py +109 -0
  47. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py +54 -0
  48. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py +58 -0
  49. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py +29 -0
  50. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py +456 -0
  51. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py +117 -0
  52. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py +228 -0
  53. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py +33 -0
  54. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py +20 -0
  55. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py +179 -0
  56. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py +690 -0
  57. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py +0 -0
  58. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml +383 -0
  59. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt +162 -0
  60. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt +8421 -0
  61. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt +163 -0
  62. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt +167 -0
  63. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt +95 -0
  64. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt +4399 -0
  65. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt +153 -0
  66. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt +3688 -0
  67. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt +185 -0
  68. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt +6623 -0
  69. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt +128 -0
  70. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt +151 -0
  71. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml +49 -0
  72. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/__init__.py +1 -0
  73. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/__init__.py +1 -0
  74. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py +106 -0
  75. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py +217 -0
  76. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +440 -0
  77. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py +104 -0
  78. magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py +227 -0
  79. magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +10 -18
  80. magic_pdf/pdf_parse_union_core_v2.py +112 -74
  81. magic_pdf/post_proc/para_split_v3.py +16 -13
  82. magic_pdf/pre_proc/ocr_dict_merge.py +9 -1
  83. magic_pdf/pre_proc/ocr_span_list_modify.py +51 -0
  84. magic_pdf/resources/model_config/model_configs.yaml +1 -1
  85. magic_pdf/tools/cli.py +30 -12
  86. magic_pdf/tools/common.py +90 -12
  87. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/METADATA +51 -41
  88. magic_pdf-1.3.0.dist-info/RECORD +202 -0
  89. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py +0 -204
  90. magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py +0 -213
  91. magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py +0 -37
  92. magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py +0 -71
  93. magic_pdf/resources/model_config/UniMERNet/demo.yaml +0 -46
  94. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +0 -351
  95. magic_pdf-1.2.1.dist-info/RECORD +0 -147
  96. /magic_pdf/model/sub_modules/{ocr/paddleocr/__init__.py → mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py} +0 -0
  97. /magic_pdf/model/sub_modules/{table/structeqtable → ocr/paddleocr2pytorch/pytorchocr}/__init__.py +0 -0
  98. /magic_pdf/model/sub_modules/{table/tablemaster → ocr/paddleocr2pytorch/pytorchocr/modeling}/__init__.py +0 -0
  99. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/LICENSE.md +0 -0
  100. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/WHEEL +0 -0
  101. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/entry_points.txt +0 -0
  102. {magic_pdf-1.2.1.dist-info → magic_pdf-1.3.0.dist-info}/top_level.txt +0 -0

magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
@@ -1,14 +1,6 @@
-import argparse
-import os
-import re
-
 import torch
-import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor
+from tqdm import tqdm


 class MathDataset(Dataset):
@@ -20,55 +12,24 @@ class MathDataset(Dataset):
         return len(self.image_paths)

     def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
+        raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
         return image


-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s
-
-
 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )

     def predict(self, mfd_res, image):
         formula_list = []
@@ -84,62 +45,22 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-            pil_img = Image.fromarray(image)
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)

-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list

-    # def batch_predict(
-    #     self, images_mfd_res: list, images: list, batch_size: int = 64
-    # ) -> list:
-    #     images_formula_list = []
-    #     mf_image_list = []
-    #     backfill_list = []
-    #     for image_index in range(len(images_mfd_res)):
-    #         mfd_res = images_mfd_res[image_index]
-    #         pil_img = Image.fromarray(images[image_index])
-    #         formula_list = []
-    #
-    #         for xyxy, conf, cla in zip(
-    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-    #         ):
-    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-    #             new_item = {
-    #                 "category_id": 13 + int(cla.item()),
-    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-    #                 "score": round(float(conf.item()), 2),
-    #                 "latex": "",
-    #             }
-    #             formula_list.append(new_item)
-    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-    #             mf_image_list.append(bbox_img)
-    #
-    #         images_formula_list.append(formula_list)
-    #         backfill_list += formula_list
-    #
-    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
-    #     mfr_res = []
-    #     for mf_img in dataloader:
-    #         mf_img = mf_img.to(self.device)
-    #         with torch.no_grad():
-    #             output = self.model.generate({"image": mf_img})
-    #         mfr_res.extend(output["pred_str"])
-    #     for res, latex in zip(backfill_list, mfr_res):
-    #         res["latex"] = latex_rm_whitespace(latex)
-    #     return images_formula_list
-
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
@@ -149,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []

             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)

                 curr_idx = len(mf_image_list)
@@ -182,22 +103,30 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}

         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

         # Process batches and store results
         mfr_res = []
-        for mf_img in dataloader:
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+        # for mf_img in dataloader:
+
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+
+                # Advance the bar by batch_size each step; the last batch may be smaller than batch_size
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)

         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex

         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):
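
Taken together, this file's changes swap the old unimernet-library config loading for HF-style `from_pretrained`, crop formula regions with numpy slicing instead of PIL, and consume the already-normalized `fixed_str` output (whitespace cleanup now lives inside the HF model; see modeling_unimernet.py below). A minimal sketch of driving the refactored wrapper; the weight directory path and the YOLO-style `mfd_res` stub are illustrative assumptions, not part of the diff:

    from types import SimpleNamespace

    import numpy as np
    import torch

    from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

    # Hypothetical local directory holding the HF-format UniMERNet weights.
    weight_dir = "/models/MFR/unimernet_hf"

    # cfg_path stays in the signature for compatibility but is unused after the refactor.
    model = UnimernetModel(weight_dir, cfg_path=None, _device_="cpu")

    # Stub shaped like a YOLO MFD result: one detected formula box.
    boxes = SimpleNamespace(
        xyxy=torch.tensor([[10.0, 10.0, 200.0, 60.0]]),
        conf=torch.tensor([0.98]),
        cls=torch.tensor([0.0]),
    )
    mfd_res = SimpleNamespace(boxes=boxes)

    page = np.full((100, 300, 3), 255, dtype=np.uint8)  # page image as an HWC ndarray
    formulas = model.predict(mfd_res, page)
    print(formulas[0]["latex"])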

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
@@ -0,0 +1,13 @@
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
+from .modeling_unimernet import UnimernetModel
+
+__all__ = [
+    "UnimerSwinConfig",
+    "UnimerSwinModel",
+    "UnimerSwinImageProcessor",
+    "UnimerMBartConfig",
+    "UnimerMBartModel",
+    "UnimerMBartForCausalLM",
+    "UnimernetModel",
+]
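
With these re-exports, callers (like the refactored Unimernet.py wrapper above) can import the whole public surface from the subpackage root:

    from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf import (
        UnimernetModel,
        UnimerSwinImageProcessor,
    )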

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
@@ -0,0 +1,189 @@
+import os
+import re
+import warnings
+from typing import Optional
+
+import torch
+from ftfy import fix_text
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
+from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
+from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
+
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
+
+AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
+AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
+AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
+AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
+
+
+# TODO: rewrite tokenizer
+class TokenizerWrapper:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.bos_token_id = self.tokenizer.bos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def tokenize(self, text, **kwargs):
+        return self.tokenizer(
+            text,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            **kwargs,
+        )
+
+    def token2str(self, tokens) -> list:
+        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+        generated_text = [fix_text(text) for text in generated_text]
+        return generated_text
+
+    def detokenize(self, tokens):
+        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ''
+                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
+                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
+                    del toks[b][i]
+        return toks
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code.
+    """
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = r'[a-zA-Z]'
+    noletter = r'[\W_^\d]'
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s
+
+
+class UnimernetModel(VisionEncoderDecoderModel):
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+        # VisionEncoderDecoderModel's config check logging has a bug; disable it temporarily.
+        base_model_logger.disabled = True
+        try:
+            super().__init__(config, encoder, decoder)
+        finally:
+            base_model_logger.disabled = False
+
+        if not config or not hasattr(config, "_name_or_path"):
+            raise RuntimeError("config._name_or_path is required by UnimernetModel.")
+
+        model_path = config._name_or_path
+        self.transform = UnimerSwinImageProcessor()
+        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
+        self._post_check()
+
+    def _post_check(self):
+        tokenizer = self.tokenizer
+
+        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
+            warnings.warn(
+                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
+                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
+                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
+            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
+
+        assert self.config.decoder.vocab_size == len(tokenizer)
+        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
+        assert self.config.pad_token_id == tokenizer.pad_token_id
+
+    @classmethod
+    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
+        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
+        config._name_or_path = model_path
+        config.encoder = UnimerSwinConfig(**vars(config.encoder))
+        config.decoder = UnimerMBartConfig(**vars(config.decoder))
+
+        encoder = UnimerSwinModel(config.encoder)
+        decoder = UnimerMBartForCausalLM(config.decoder)
+        model = cls(config, encoder, decoder)
+
+        # load model weights
+        model_file_path = os.path.join(model_path, model_filename)
+        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
+        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
+        if not state_dict:
+            raise RuntimeError("state_dict is empty.")
+        if state_dict_strip_prefix:
+            state_dict = {
+                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
+                for k, v in state_dict.items()
+            }
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if len(unexpected_keys) > 0:
+            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
+        return model
+
+    def forward_bak(self, samples):
+        pixel_values, text = samples["image"], samples["text_input"]
+
+        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
+        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
+
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
+
+        loss = self.model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids[:, :-1],
+            decoder_attention_mask=decoder_attention_mask[:, :-1],
+            labels=labels[:, 1:],
+        ).loss
+        return {"loss": loss}
+
+    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
+        pixel_values = samples["image"]
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        kwargs = {}
+        if do_sample:
+            kwargs["temperature"] = temperature
+            kwargs["top_p"] = top_p
+
+        outputs = super().generate(
+            pixel_values=pixel_values,
+            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
+            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
+            do_sample=do_sample,
+            **kwargs,
+        )
+
+        outputs = outputs[:, 1:].cpu().numpy()
+        pred_tokens = self.tokenizer.detokenize(outputs)
+        pred_str = self.tokenizer.token2str(outputs)
+        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
+        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
+
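
Note that `generate` returns both the raw decode (`pred_str`) and its whitespace-normalized form (`fixed_str`); this is why the Unimernet.py wrapper above switched from `output["pred_str"]` plus a local `latex_rm_whitespace` call to plain `output["fixed_str"]`. The normalizer is self-contained, so its behavior is easy to check in isolation; a small sketch with an illustrative input (expected output derived by hand from the regexes above):

    from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.modeling_unimernet import latex_rm_whitespace

    # Whitespace that does not separate two letters is collapsed away.
    print(latex_rm_whitespace(r"\mathrm {d} x ^ { 2 } + y _ { 1 }"))
    # -> \mathrm{d}x^{2}+y_{1}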

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
@@ -0,0 +1,8 @@
+from .configuration_unimer_mbart import UnimerMBartConfig
+from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM
+
+__all__ = [
+    "UnimerMBartConfig",
+    "UnimerMBartModel",
+    "UnimerMBartForCausalLM",
+]

magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""UnimerMBART model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerMBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`UnimerMBartModel`]. It is used to instantiate an MBART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MBART
+    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        qk_squeeze (`int`, *optional*, defaults to 2):
+            Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
+            Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
+            thereby accelerating the computation of attention.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import MBartConfig, MBartModel
+
+    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
+    >>> configuration = MBartConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
+    >>> model = MBartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "unimer-mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        qk_squeeze=2,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.qk_squeeze = qk_squeeze
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
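
Because modeling_unimernet.py registers this config with transformers' Auto classes under model_type "unimer-mbart", it can also be constructed directly; a small sketch with deliberately made-up sizes (real checkpoints ship their own config.json):

    from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.unimer_mbart import (
        UnimerMBartConfig,
        UnimerMBartForCausalLM,
    )

    config = UnimerMBartConfig(
        vocab_size=50000,
        d_model=512,
        qk_squeeze=2,  # query/key projections squeezed to a lower dimension, per the docstring above
        decoder_layers=4,
        decoder_attention_heads=8,
    )
    decoder = UnimerMBartForCausalLM(config)
    print(decoder.config.hidden_size)  # attribute_map resolves hidden_size -> d_model: 512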