paddlex 3.0.0rc1__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233)
  1. paddlex/.version +1 -1
  2. paddlex/__init__.py +1 -1
  3. paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
  4. paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
  5. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
  6. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
  7. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
  8. paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
  9. paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +2 -2
  10. paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +2 -2
  11. paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +2 -2
  12. paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
  13. paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
  14. paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
  15. paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
  16. paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
  17. paddlex/configs/modules/textline_orientation/PP-LCNet_x1_0_textline_ori.yaml +41 -0
  18. paddlex/configs/pipelines/OCR.yaml +7 -6
  19. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +3 -1
  20. paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +91 -34
  21. paddlex/configs/pipelines/PP-StructureV3.yaml +72 -72
  22. paddlex/configs/pipelines/doc_understanding.yaml +1 -1
  23. paddlex/configs/pipelines/formula_recognition.yaml +2 -2
  24. paddlex/configs/pipelines/layout_parsing.yaml +3 -2
  25. paddlex/configs/pipelines/seal_recognition.yaml +1 -0
  26. paddlex/configs/pipelines/table_recognition.yaml +2 -1
  27. paddlex/configs/pipelines/table_recognition_v2.yaml +7 -1
  28. paddlex/hpip_links.html +20 -20
  29. paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +33 -10
  30. paddlex/inference/common/batch_sampler/image_batch_sampler.py +34 -25
  31. paddlex/inference/common/result/mixin.py +19 -12
  32. paddlex/inference/models/base/predictor/base_predictor.py +2 -8
  33. paddlex/inference/models/common/static_infer.py +11 -59
  34. paddlex/inference/models/common/tokenizer/__init__.py +2 -0
  35. paddlex/inference/models/common/tokenizer/clip_tokenizer.py +1 -1
  36. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +2 -2
  37. paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
  38. paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +7 -1
  39. paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
  40. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +13 -13
  41. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3 -3
  42. paddlex/inference/models/common/tokenizer/vocab.py +7 -7
  43. paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
  44. paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
  45. paddlex/inference/models/common/vlm/generation/configuration_utils.py +1 -1
  46. paddlex/inference/models/common/vlm/generation/logits_process.py +1 -1
  47. paddlex/inference/models/common/vlm/generation/utils.py +1 -1
  48. paddlex/inference/models/common/vlm/transformers/configuration_utils.py +3 -3
  49. paddlex/inference/models/common/vlm/transformers/conversion_utils.py +3 -3
  50. paddlex/inference/models/common/vlm/transformers/model_outputs.py +2 -2
  51. paddlex/inference/models/common/vlm/transformers/model_utils.py +7 -31
  52. paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
  53. paddlex/inference/models/doc_vlm/modeling/__init__.py +2 -0
  54. paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
  55. paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
  56. paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +0 -105
  57. paddlex/inference/models/doc_vlm/predictor.py +79 -24
  58. paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
  59. paddlex/inference/models/doc_vlm/processors/__init__.py +2 -0
  60. paddlex/inference/models/doc_vlm/processors/common.py +189 -0
  61. paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
  62. paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +21 -176
  63. paddlex/inference/models/formula_recognition/predictor.py +7 -1
  64. paddlex/inference/models/formula_recognition/processors.py +92 -79
  65. paddlex/inference/models/formula_recognition/result.py +28 -27
  66. paddlex/inference/models/image_feature/processors.py +3 -4
  67. paddlex/inference/models/keypoint_detection/predictor.py +3 -0
  68. paddlex/inference/models/object_detection/predictor.py +2 -0
  69. paddlex/inference/models/object_detection/processors.py +28 -3
  70. paddlex/inference/models/object_detection/utils.py +2 -0
  71. paddlex/inference/models/table_structure_recognition/result.py +0 -10
  72. paddlex/inference/models/text_detection/predictor.py +8 -0
  73. paddlex/inference/models/text_detection/processors.py +44 -10
  74. paddlex/inference/models/text_detection/result.py +0 -10
  75. paddlex/inference/pipelines/__init__.py +9 -5
  76. paddlex/inference/pipelines/_parallel.py +172 -0
  77. paddlex/inference/pipelines/anomaly_detection/pipeline.py +16 -6
  78. paddlex/inference/pipelines/attribute_recognition/pipeline.py +11 -1
  79. paddlex/inference/pipelines/base.py +14 -4
  80. paddlex/inference/pipelines/components/faisser.py +1 -1
  81. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +53 -27
  82. paddlex/inference/pipelines/formula_recognition/pipeline.py +120 -82
  83. paddlex/inference/pipelines/formula_recognition/result.py +1 -11
  84. paddlex/inference/pipelines/image_classification/pipeline.py +16 -6
  85. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +16 -6
  86. paddlex/inference/pipelines/instance_segmentation/pipeline.py +16 -6
  87. paddlex/inference/pipelines/keypoint_detection/pipeline.py +16 -6
  88. paddlex/inference/pipelines/layout_parsing/pipeline.py +34 -47
  89. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +893 -260
  90. paddlex/inference/pipelines/layout_parsing/result.py +4 -17
  91. paddlex/inference/pipelines/layout_parsing/result_v2.py +523 -245
  92. paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
  93. paddlex/inference/pipelines/layout_parsing/utils.py +565 -1998
  94. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
  95. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1144 -0
  96. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +563 -0
  97. paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +2 -2
  98. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +2 -2
  99. paddlex/inference/pipelines/object_detection/pipeline.py +16 -6
  100. paddlex/inference/pipelines/ocr/pipeline.py +127 -70
  101. paddlex/inference/pipelines/ocr/result.py +19 -16
  102. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +2 -2
  103. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +2 -2
  104. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +2 -2
  105. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +2 -5
  106. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +5 -5
  107. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +16 -6
  108. paddlex/inference/pipelines/seal_recognition/pipeline.py +109 -53
  109. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +16 -6
  110. paddlex/inference/pipelines/small_object_detection/pipeline.py +16 -6
  111. paddlex/inference/pipelines/table_recognition/pipeline.py +26 -18
  112. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +624 -53
  113. paddlex/inference/pipelines/table_recognition/result.py +1 -1
  114. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +9 -5
  115. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +2 -2
  116. paddlex/inference/pipelines/ts_classification/pipeline.py +2 -2
  117. paddlex/inference/pipelines/ts_forecasting/pipeline.py +2 -2
  118. paddlex/inference/pipelines/video_classification/pipeline.py +2 -2
  119. paddlex/inference/pipelines/video_detection/pipeline.py +2 -2
  120. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +5 -1
  121. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +0 -1
  122. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +0 -1
  123. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +1 -1
  124. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +6 -2
  125. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +1 -5
  126. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +4 -5
  127. paddlex/inference/serving/infra/utils.py +20 -22
  128. paddlex/inference/serving/schemas/formula_recognition.py +1 -1
  129. paddlex/inference/serving/schemas/layout_parsing.py +1 -2
  130. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +1 -2
  131. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +2 -2
  132. paddlex/inference/serving/schemas/pp_structurev3.py +10 -6
  133. paddlex/inference/serving/schemas/seal_recognition.py +1 -1
  134. paddlex/inference/serving/schemas/table_recognition.py +2 -6
  135. paddlex/inference/serving/schemas/table_recognition_v2.py +5 -6
  136. paddlex/inference/utils/hpi.py +8 -1
  137. paddlex/inference/utils/hpi_model_info_collection.json +81 -2
  138. paddlex/inference/utils/io/readers.py +12 -12
  139. paddlex/inference/utils/mkldnn_blocklist.py +25 -0
  140. paddlex/inference/utils/official_models.py +14 -0
  141. paddlex/inference/utils/pp_option.py +29 -8
  142. paddlex/model.py +2 -2
  143. paddlex/modules/__init__.py +1 -1
  144. paddlex/modules/anomaly_detection/evaluator.py +2 -2
  145. paddlex/modules/base/__init__.py +1 -1
  146. paddlex/modules/base/evaluator.py +5 -5
  147. paddlex/modules/base/trainer.py +1 -1
  148. paddlex/modules/doc_vlm/dataset_checker.py +2 -2
  149. paddlex/modules/doc_vlm/evaluator.py +2 -2
  150. paddlex/modules/doc_vlm/exportor.py +2 -2
  151. paddlex/modules/doc_vlm/model_list.py +1 -1
  152. paddlex/modules/doc_vlm/trainer.py +2 -2
  153. paddlex/modules/face_recognition/evaluator.py +2 -2
  154. paddlex/modules/formula_recognition/evaluator.py +5 -2
  155. paddlex/modules/formula_recognition/model_list.py +3 -0
  156. paddlex/modules/formula_recognition/trainer.py +3 -0
  157. paddlex/modules/general_recognition/evaluator.py +1 -1
  158. paddlex/modules/image_classification/evaluator.py +2 -2
  159. paddlex/modules/image_classification/model_list.py +1 -0
  160. paddlex/modules/instance_segmentation/evaluator.py +1 -1
  161. paddlex/modules/keypoint_detection/evaluator.py +1 -1
  162. paddlex/modules/m_3d_bev_detection/evaluator.py +2 -2
  163. paddlex/modules/multilabel_classification/evaluator.py +2 -2
  164. paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +4 -4
  165. paddlex/modules/object_detection/evaluator.py +2 -2
  166. paddlex/modules/object_detection/model_list.py +2 -0
  167. paddlex/modules/semantic_segmentation/evaluator.py +2 -2
  168. paddlex/modules/table_recognition/evaluator.py +2 -2
  169. paddlex/modules/text_detection/evaluator.py +2 -2
  170. paddlex/modules/text_detection/model_list.py +2 -0
  171. paddlex/modules/text_recognition/evaluator.py +2 -2
  172. paddlex/modules/text_recognition/model_list.py +2 -0
  173. paddlex/modules/ts_anomaly_detection/evaluator.py +2 -2
  174. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  175. paddlex/modules/ts_classification/evaluator.py +2 -2
  176. paddlex/modules/ts_forecast/evaluator.py +2 -2
  177. paddlex/modules/video_classification/evaluator.py +2 -2
  178. paddlex/modules/video_detection/evaluator.py +2 -2
  179. paddlex/ops/__init__.py +2 -2
  180. paddlex/paddlex_cli.py +19 -13
  181. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +2 -2
  182. paddlex/repo_apis/PaddleClas_api/cls/config.py +1 -1
  183. paddlex/repo_apis/PaddleClas_api/cls/model.py +1 -1
  184. paddlex/repo_apis/PaddleClas_api/cls/register.py +10 -0
  185. paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -1
  186. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +1 -1
  187. paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -1
  188. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +1 -1
  189. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +1 -1
  190. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +25 -0
  191. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +30 -0
  192. paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -1
  193. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +3 -3
  194. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +5 -9
  195. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +27 -0
  196. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +1 -1
  197. paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +1 -1
  198. paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +1 -1
  199. paddlex/repo_apis/PaddleOCR_api/text_det/model.py +1 -1
  200. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +18 -0
  201. paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +1 -1
  202. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +3 -3
  203. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +5 -9
  204. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +18 -0
  205. paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -1
  206. paddlex/repo_apis/PaddleSeg_api/seg/model.py +1 -1
  207. paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -1
  208. paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +3 -3
  209. paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -2
  210. paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +4 -4
  211. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +1 -1
  212. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +1 -1
  213. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +1 -1
  214. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +1 -1
  215. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +1 -1
  216. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +1 -1
  217. paddlex/repo_apis/base/config.py +1 -1
  218. paddlex/repo_manager/core.py +3 -3
  219. paddlex/repo_manager/meta.py +6 -2
  220. paddlex/repo_manager/repo.py +17 -16
  221. paddlex/utils/custom_device_list.py +26 -2
  222. paddlex/utils/deps.py +1 -1
  223. paddlex/utils/device.py +15 -8
  224. paddlex/utils/env.py +4 -0
  225. paddlex/utils/flags.py +2 -4
  226. paddlex/utils/fonts/__init__.py +34 -4
  227. paddlex/utils/misc.py +1 -1
  228. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/METADATA +52 -56
  229. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/RECORD +233 -206
  230. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/WHEEL +1 -1
  231. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/entry_points.txt +0 -0
  232. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/licenses/LICENSE +0 -0
  233. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/top_level.txt +0 -0
paddlex/inference/pipelines/table_recognition/pipeline_v2.py +624 -53

@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+import re
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -28,9 +29,11 @@ from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
+from .._parallel import AutoParallelImageSimpleInferencePipeline
 from ..base import BasePipeline
 from ..components import CropByBoxes
 from ..doc_preprocessor.result import DocPreprocessorResult
+from ..layout_parsing.utils import get_sub_regions_ocr_res
 from ..ocr.result import OCRResult
 from .result import SingleTableRecognitionResult, TableRecognitionResult
 from .table_recognition_post_processing import (
@@ -43,12 +46,9 @@ if is_dep_available("scikit-learn"):
     from sklearn.cluster import KMeans


-@pipeline_requires_extra("ocr")
-class TableRecognitionPipelineV2(BasePipeline):
+class _TableRecognitionPipelineV2(BasePipeline):
     """Table Recognition Pipeline"""

-    entities = ["table_recognition_v2"]
-
     def __init__(
         self,
         config: Dict,
@@ -64,9 +64,9 @@ class TableRecognitionPipelineV2(BasePipeline):
             device (str, optional): Device to run the predictions on. Defaults to None.
             pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
             use_hpip (bool, optional): Whether to use the high-performance
-                inference plugin (HPIP). Defaults to False.
+                inference plugin (HPIP) by default. Defaults to False.
             hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
-                The high-performance inference configuration dictionary.
+                The default high-performance inference configuration dictionary.
                 Defaults to None.
         """

@@ -133,6 +133,7 @@
         )

         self.use_ocr_model = config.get("use_ocr_model", True)
+        self.general_ocr_pipeline = None
         if self.use_ocr_model:
             general_ocr_config = config.get("SubPipelines", {}).get(
                 "GeneralOCR",
@@ -144,8 +145,12 @@
                 "GeneralOCR", None
             )

-        self._crop_by_boxes = CropByBoxes()
+        self.table_orientation_classify_model = None
+        self.table_orientation_classify_config = config.get("SubModules", {}).get(
+            "TableOrientationClassify", None
+        )

+        self._crop_by_boxes = CropByBoxes()
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")

@@ -539,7 +544,177 @@
         final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
         return final_results

-    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
+    def split_ocr_bboxes_by_table_cells(
+        self, cells_det_results, overall_ocr_res, ori_img, k=2
+    ):
+        """
+        Split OCR bounding boxes based on table cell boundaries when they span multiple cells horizontally.
+
+        Args:
+            cells_det_results (list): List of cell bounding boxes in format [x1, y1, x2, y2]
+            overall_ocr_res (dict): Dictionary containing OCR results with keys:
+                - 'rec_boxes': OCR bounding boxes (will be converted to list)
+                - 'rec_texts': OCR recognized texts
+            ori_img (np.array): Original input image array
+            k (int): Threshold for determining when to split (minimum number of cells spanned)
+
+        Returns:
+            dict: Modified overall_ocr_res with split boxes and texts
+        """
+
+        def calculate_iou(box1, box2):
+            """
+            Calculate Intersection over Union (IoU) between two bounding boxes.
+
+            Args:
+                box1 (list): [x1, y1, x2, y2]
+                box2 (list): [x1, y1, x2, y2]
+
+            Returns:
+                float: IoU value
+            """
+            # Determine intersection coordinates
+            x_left = max(box1[0], box2[0])
+            y_top = max(box1[1], box2[1])
+            x_right = min(box1[2], box2[2])
+            y_bottom = min(box1[3], box2[3])
+            if x_right < x_left or y_bottom < y_top:
+                return 0.0
+            # Calculate areas
+            intersection_area = (x_right - x_left) * (y_bottom - y_top)
+            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+            box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+            # return intersection_area / float(box1_area + box2_area - intersection_area)
+            return intersection_area / box2_area
+
+        def get_overlapping_cells(ocr_box, cells):
+            """
+            Find cells that overlap significantly with the OCR box (IoU > 0.5).
+
+            Args:
+                ocr_box (list): OCR bounding box [x1, y1, x2, y2]
+                cells (list): List of cell bounding boxes
+
+            Returns:
+                list: Indices of overlapping cells, sorted by x-coordinate
+            """
+            overlapping = []
+            for idx, cell in enumerate(cells):
+                if calculate_iou(ocr_box, cell) > 0.5:
+                    overlapping.append(idx)
+            # Sort overlapping cells by their x-coordinate (left to right)
+            overlapping.sort(key=lambda i: cells[i][0])
+            return overlapping
+
+        def split_box_by_cells(ocr_box, cell_indices, cells):
+            """
+            Split OCR box vertically at cell boundaries.
+
+            Args:
+                ocr_box (list): Original OCR box [x1, y1, x2, y2]
+                cell_indices (list): Indices of cells to split by
+                cells (list): All cell bounding boxes
+
+            Returns:
+                list: List of split boxes
+            """
+            if not cell_indices:
+                return [ocr_box]
+            split_boxes = []
+            cells_to_split = [cells[i] for i in cell_indices]
+            if ocr_box[0] < cells_to_split[0][0]:
+                split_boxes.append(
+                    [ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]
+                )
+            for i in range(len(cells_to_split)):
+                current_cell = cells_to_split[i]
+                split_boxes.append(
+                    [
+                        max(ocr_box[0], current_cell[0]),
+                        ocr_box[1],
+                        min(ocr_box[2], current_cell[2]),
+                        ocr_box[3],
+                    ]
+                )
+                if i < len(cells_to_split) - 1:
+                    next_cell = cells_to_split[i + 1]
+                    if current_cell[2] < next_cell[0]:
+                        split_boxes.append(
+                            [current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]
+                        )
+            last_cell = cells_to_split[-1]
+            if last_cell[2] < ocr_box[2]:
+                split_boxes.append([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]])
+            unique_boxes = []
+            seen = set()
+            for box in split_boxes:
+                box_tuple = tuple(box)
+                if box_tuple not in seen:
+                    seen.add(box_tuple)
+                    unique_boxes.append(box)
+
+            return unique_boxes
+
+        # Convert OCR boxes to list if needed
+        if hasattr(overall_ocr_res["rec_boxes"], "tolist"):
+            ocr_det_results = overall_ocr_res["rec_boxes"].tolist()
+        else:
+            ocr_det_results = overall_ocr_res["rec_boxes"]
+        ocr_texts = overall_ocr_res["rec_texts"]
+
+        # Make copies to modify
+        new_boxes = []
+        new_texts = []
+
+        # Process each OCR box
+        i = 0
+        while i < len(ocr_det_results):
+            ocr_box = ocr_det_results[i]
+            text = ocr_texts[i]
+            # Find cells that significantly overlap with this OCR box
+            overlapping_cells = get_overlapping_cells(ocr_box, cells_det_results)
+            # Check if we need to split (spans >= k cells)
+            if len(overlapping_cells) >= k:
+                # Split the box at cell boundaries
+                split_boxes = split_box_by_cells(
+                    ocr_box, overlapping_cells, cells_det_results
+                )
+                # Process each split box
+                split_texts = []
+                for box in split_boxes:
+                    x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
+                    if y2 - y1 > 1 and x2 - x1 > 1:
+                        ocr_result = next(
+                            self.general_ocr_pipeline.text_rec_model(
+                                ori_img[y1:y2, x1:x2, :]
+                            )
+                        )
+                        # Extract the recognized text from the OCR result
+                        if "rec_text" in ocr_result:
+                            result = ocr_result[
+                                "rec_text"
+                            ]  # Assumes "rec_texts" contains a single string
+                        else:
+                            result = ""
+                    else:
+                        result = ""
+                    split_texts.append(result)
+                # Add split boxes and texts to results
+                new_boxes.extend(split_boxes)
+                new_texts.extend(split_texts)
+            else:
+                # Keep original box and text
+                new_boxes.append(ocr_box)
+                new_texts.append(text)
+            i += 1
+
+        # Update the results dictionary
+        overall_ocr_res["rec_boxes"] = new_boxes
+        overall_ocr_res["rec_texts"] = new_texts
+
+        return overall_ocr_res
+
+    def gen_ocr_with_table_cells(self, ori_img, cells_bboxes):
         """
         Splits OCR bounding boxes by table cells and retrieves text.

@@ -560,20 +735,228 @@
             # Extract and round up the coordinates of the bounding box.
             x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
             # Perform OCR on the defined region of the image and get the recognized text.
-            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
-            # Concatenate the texts and append them to the texts_list.
-            texts_list.append("".join(rec_te["rec_texts"]))
+            if y2 - y1 > 1 and x2 - x1 > 1:
+                rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+                # Concatenate the texts and append them to the texts_list.
+                texts_list.append("".join(rec_te["rec_texts"]))
         # Return the list of recognized texts from each cell.
         return texts_list

+    def map_cells_to_original_image(
+        self, detections, table_angle, img_width, img_height
+    ):
+        """
+        Map bounding boxes from the rotated image back to the original image.
+
+        Parameters:
+        - detections: list of numpy arrays, each containing bounding box coordinates [x1, y1, x2, y2]
+        - table_angle: rotation angle in degrees (90, 180, or 270)
+        - width_orig: width of the original image (img1)
+        - height_orig: height of the original image (img1)
+
+        Returns:
+        - mapped_detections: list of numpy arrays with mapped bounding box coordinates
+        """
+
+        mapped_detections = []
+        for i in range(len(detections)):
+            tbx1, tby1, tbx2, tby2 = (
+                detections[i][0],
+                detections[i][1],
+                detections[i][2],
+                detections[i][3],
+            )
+            if table_angle == "270":
+                new_x1, new_y1 = tby1, img_width - tbx2
+                new_x2, new_y2 = tby2, img_width - tbx1
+            elif table_angle == "180":
+                new_x1, new_y1 = img_width - tbx2, img_height - tby2
+                new_x2, new_y2 = img_width - tbx1, img_height - tby1
+            elif table_angle == "90":
+                new_x1, new_y1 = img_height - tby2, tbx1
+                new_x2, new_y2 = img_height - tby1, tbx2
+            new_box = np.array([new_x1, new_y1, new_x2, new_y2])
+            mapped_detections.append(new_box)
+        return mapped_detections
+
+    def split_string_by_keywords(self, html_string):
+        """
+        Split HTML string by keywords.
+
+        Args:
+            html_string (str): The HTML string.
+        Returns:
+            split_html (list): The list of html keywords.
+        """
+
+        keywords = [
+            "<thead>",
+            "</thead>",
+            "<tbody>",
+            "</tbody>",
+            "<tr>",
+            "</tr>",
+            "<td>",
+            "<td",
+            ">",
+            "</td>",
+            'colspan="2"',
+            'colspan="3"',
+            'colspan="4"',
+            'colspan="5"',
+            'colspan="6"',
+            'colspan="7"',
+            'colspan="8"',
+            'colspan="9"',
+            'colspan="10"',
+            'colspan="11"',
+            'colspan="12"',
+            'colspan="13"',
+            'colspan="14"',
+            'colspan="15"',
+            'colspan="16"',
+            'colspan="17"',
+            'colspan="18"',
+            'colspan="19"',
+            'colspan="20"',
+            'rowspan="2"',
+            'rowspan="3"',
+            'rowspan="4"',
+            'rowspan="5"',
+            'rowspan="6"',
+            'rowspan="7"',
+            'rowspan="8"',
+            'rowspan="9"',
+            'rowspan="10"',
+            'rowspan="11"',
+            'rowspan="12"',
+            'rowspan="13"',
+            'rowspan="14"',
+            'rowspan="15"',
+            'rowspan="16"',
+            'rowspan="17"',
+            'rowspan="18"',
+            'rowspan="19"',
+            'rowspan="20"',
+        ]
+        regex_pattern = "|".join(re.escape(keyword) for keyword in keywords)
+        split_result = re.split(f"({regex_pattern})", html_string)
+        split_html = [part for part in split_result if part]
+        return split_html
+
+    def cluster_positions(self, positions, tolerance):
+        if not positions:
+            return []
+        positions = sorted(set(positions))
+        clustered = []
+        current_cluster = [positions[0]]
+        for pos in positions[1:]:
+            if abs(pos - current_cluster[-1]) <= tolerance:
+                current_cluster.append(pos)
+            else:
+                clustered.append(sum(current_cluster) / len(current_cluster))
+                current_cluster = [pos]
+        clustered.append(sum(current_cluster) / len(current_cluster))
+        return clustered
+
+    def trans_cells_det_results_to_html(self, cells_det_results):
+        """
+        Trans table cells bboxes to HTML.
+
+        Args:
+            cells_det_results (list): The table cells detection results.
+        Returns:
+            html (list): The list of html keywords.
+        """
+
+        tolerance = 5
+        x_coords = [x for cell in cells_det_results for x in (cell[0], cell[2])]
+        y_coords = [y for cell in cells_det_results for y in (cell[1], cell[3])]
+        x_positions = self.cluster_positions(x_coords, tolerance)
+        y_positions = self.cluster_positions(y_coords, tolerance)
+        x_position_to_index = {x: i for i, x in enumerate(x_positions)}
+        y_position_to_index = {y: i for i, y in enumerate(y_positions)}
+        num_rows = len(y_positions) - 1
+        num_cols = len(x_positions) - 1
+        grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
+        cells_info = []
+        cell_index = 0
+        cell_map = {}
+        for index, cell in enumerate(cells_det_results):
+            x1, y1, x2, y2 = cell
+            x1_idx = min(
+                range(len(x_positions)), key=lambda i: abs(x_positions[i] - x1)
+            )
+            x2_idx = min(
+                range(len(x_positions)), key=lambda i: abs(x_positions[i] - x2)
+            )
+            y1_idx = min(
+                range(len(y_positions)), key=lambda i: abs(y_positions[i] - y1)
+            )
+            y2_idx = min(
+                range(len(y_positions)), key=lambda i: abs(y_positions[i] - y2)
+            )
+            col_start = min(x1_idx, x2_idx)
+            col_end = max(x1_idx, x2_idx)
+            row_start = min(y1_idx, y2_idx)
+            row_end = max(y1_idx, y2_idx)
+            rowspan = row_end - row_start
+            colspan = col_end - col_start
+            if rowspan == 0:
+                rowspan = 1
+            if colspan == 0:
+                colspan = 1
+            cells_info.append(
+                {
+                    "row_start": row_start,
+                    "col_start": col_start,
+                    "rowspan": rowspan,
+                    "colspan": colspan,
+                    "content": "",
+                }
+            )
+            for r in range(row_start, row_start + rowspan):
+                for c in range(col_start, col_start + colspan):
+                    key = (r, c)
+                    if key in cell_map:
+                        continue
+                    else:
+                        cell_map[key] = index
+        html = "<table><tbody>"
+        for r in range(num_rows):
+            html += "<tr>"
+            c = 0
+            while c < num_cols:
+                key = (r, c)
+                if key in cell_map:
+                    cell_index = cell_map[key]
+                    cell_info = cells_info[cell_index]
+                    if cell_info["row_start"] == r and cell_info["col_start"] == c:
+                        rowspan = cell_info["rowspan"]
+                        colspan = cell_info["colspan"]
+                        rowspan_attr = f' rowspan="{rowspan}"' if rowspan > 1 else ""
+                        colspan_attr = f' colspan="{colspan}"' if colspan > 1 else ""
+                        content = cell_info["content"]
+                        html += f"<td{rowspan_attr}{colspan_attr}>{content}</td>"
+                    c += cell_info["colspan"]
+                else:
+                    html += "<td></td>"
+                    c += 1
+            html += "</tr>"
+        html += "</tbody></table>"
+        html = self.split_string_by_keywords(html)
+        return html
+
     def predict_single_table_recognition_res(
         self,
         image_array: np.ndarray,
         overall_ocr_res: OCRResult,
         table_box: list,
-        use_table_cells_ocr_results: bool = False,
         use_e2e_wired_table_rec_model: bool = False,
         use_e2e_wireless_table_rec_model: bool = False,
+        use_wired_table_cells_trans_to_html: bool = False,
+        use_wireless_table_cells_trans_to_html: bool = False,
+        use_ocr_results_with_table_cells: bool = True,
         flag_find_nei_text: bool = True,
     ) -> SingleTableRecognitionResult:
         """
@@ -584,9 +967,11 @@
             overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                 The overall OCR results containing text recognition information.
             table_box (list): The table box coordinates.
-            use_table_cells_ocr_results (bool): whether to use OCR results with cells.
             use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
             use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
+            use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML.
+            use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML.
+            use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
             flag_find_nei_text (bool): Whether to find neighboring text.
         Returns:
             SingleTableRecognitionResult: single table recognition result.
@@ -595,20 +980,33 @@
         table_cls_pred = next(self.table_cls_model(image_array))
         table_cls_result = self.extract_results(table_cls_pred, "cls")
         use_e2e_model = False
+        cells_trans_to_html = False

         if table_cls_result == "wired_table":
-            table_structure_pred = next(self.wired_table_rec_model(image_array))
+            if use_wired_table_cells_trans_to_html == True:
+                cells_trans_to_html = True
+            else:
+                table_structure_pred = next(self.wired_table_rec_model(image_array))
             if use_e2e_wired_table_rec_model == True:
                 use_e2e_model = True
+                if cells_trans_to_html == True:
+                    table_structure_pred = next(self.wired_table_rec_model(image_array))
             else:
                 table_cells_pred = next(
                     self.wired_table_cells_detection_model(image_array, threshold=0.3)
                 )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         elif table_cls_result == "wireless_table":
-            table_structure_pred = next(self.wireless_table_rec_model(image_array))
+            if use_wireless_table_cells_trans_to_html == True:
+                cells_trans_to_html = True
+            else:
+                table_structure_pred = next(self.wireless_table_rec_model(image_array))
             if use_e2e_wireless_table_rec_model == True:
                 use_e2e_model = True
+                if cells_trans_to_html == True:
+                    table_structure_pred = next(
+                        self.wireless_table_rec_model(image_array)
+                    )
             else:
                 table_cells_pred = next(
                     self.wireless_table_cells_detection_model(
@@ -618,58 +1016,74 @@
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

         if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
             table_cells_result, table_cells_score = self.extract_results(
                 table_cells_pred, "det"
             )
             table_cells_result, table_cells_score = self.cells_det_results_nms(
                 table_cells_result, table_cells_score
             )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
-            table_cells_result = self.cells_det_results_reprocessing(
-                table_cells_result,
-                table_cells_score,
-                ocr_det_boxes,
-                len(table_structure_pred["bbox"]),
-            )
-            if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
+            if cells_trans_to_html == True:
+                table_structure_result = self.trans_cells_det_results_to_html(
+                    table_cells_result
+                )
+            else:
+                table_structure_result = self.extract_results(
+                    table_structure_pred, "table_stru"
+                )
+                ocr_det_boxes = self.get_region_ocr_det_boxes(
+                    overall_ocr_res["rec_boxes"].tolist(), table_box
                 )
+                table_cells_result = self.cells_det_results_reprocessing(
+                    table_cells_result,
+                    table_cells_score,
+                    ocr_det_boxes,
+                    len(table_structure_pred["bbox"]),
+                )
+            if use_ocr_results_with_table_cells == True:
+                if self.cells_split_ocr == True:
+                    table_box_copy = np.array([table_box])
+                    table_ocr_pred = get_sub_regions_ocr_res(
+                        overall_ocr_res, table_box_copy
+                    )
+                    table_ocr_pred = self.split_ocr_bboxes_by_table_cells(
+                        table_cells_result, table_ocr_pred, image_array
+                    )
+                    cells_texts_list = []
+                else:
+                    cells_texts_list = self.gen_ocr_with_table_cells(
+                        image_array, table_cells_result
+                    )
+                    table_ocr_pred = {}
             else:
+                table_ocr_pred = {}
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
                 table_box,
                 table_structure_result,
                 table_cells_result,
                 overall_ocr_res,
+                table_ocr_pred,
                 cells_texts_list,
-                use_table_cells_ocr_results,
+                use_ocr_results_with_table_cells,
+                self.cells_split_ocr,
             )
         else:
-            if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+            cells_texts_list = []
+            use_ocr_results_with_table_cells = False
+            table_cells_result_e2e = table_structure_pred["bbox"]
+            table_cells_result_e2e = [
+                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e
+            ]
+            if cells_trans_to_html == True:
+                table_structure_pred["structure"] = (
+                    self.trans_cells_det_results_to_html(table_cells_result_e2e)
                 )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
-                )
-            else:
-                cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
                 table_box,
                 table_structure_pred,
                 overall_ocr_res,
                 cells_texts_list,
-                use_table_cells_ocr_results,
+                use_ocr_results_with_table_cells,
             )

         neighbor_text = ""
@@ -698,9 +1112,12 @@
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
-        use_table_cells_ocr_results: bool = False,
         use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
+        use_wired_table_cells_trans_to_html: bool = False,
+        use_wireless_table_cells_trans_to_html: bool = False,
+        use_table_orientation_classify: bool = True,
+        use_ocr_results_with_table_cells: bool = True,
         **kwargs,
     ) -> TableRecognitionResult:
         """
@@ -715,16 +1132,28 @@
                 It will be used if it is not None and use_ocr_model is False.
             layout_det_res (DetResult): The layout detection result.
                 It will be used if it is not None and use_layout_detection is False.
-            use_table_cells_ocr_results (bool): whether to use OCR results with cells.
             use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
             use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
-            flag_find_nei_text (bool): Whether to find neighboring text.
+            use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML.
+            use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML.
+            use_table_orientation_classify (bool): Whether to use table orientation classification.
+            use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
             **kwargs: Additional keyword arguments.

         Returns:
             TableRecognitionResult: The predicted table recognition result.
         """

+        self.cells_split_ocr = True
+
+        if use_table_orientation_classify == True and (
+            self.table_orientation_classify_model is None
+        ):
+            assert self.table_orientation_classify_config != None
+            self.table_orientation_classify_model = self.create_model(
+                self.table_orientation_classify_config
+            )
+
         model_settings = self.get_model_settings(
             use_doc_orientation_classify,
             use_doc_unwarping,
@@ -765,50 +1194,179 @@
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
             )
-        elif use_table_cells_ocr_results == True:
+        elif self.general_ocr_pipeline is None and (
+            (
+                use_ocr_results_with_table_cells == True
+                and self.cells_split_ocr == False
+            )
+            or use_table_orientation_classify == True
+        ):
            assert self.general_ocr_config_bak != None
            self.general_ocr_pipeline = self.create_pipeline(
                self.general_ocr_config_bak
            )

+        if use_table_orientation_classify == False:
+            table_angle = "0"
+
         table_res_list = []
         table_region_id = 1
+
         if not model_settings["use_layout_detection"] and layout_det_res is None:
-            layout_det_res = {}
             img_height, img_width = doc_preprocessor_image.shape[:2]
             table_box = [0, 0, img_width - 1, img_height - 1]
+            if use_table_orientation_classify == True:
+                table_angle = next(
+                    self.table_orientation_classify_model(doc_preprocessor_image)
+                )["label_names"][0]
+            if table_angle == "90":
+                doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1)
+            elif table_angle == "180":
+                doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=2)
+            elif table_angle == "270":
+                doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3)
+            if table_angle in ["90", "180", "270"]:
+                overall_ocr_res = next(
+                    self.general_ocr_pipeline(
+                        doc_preprocessor_image,
+                        text_det_limit_side_len=text_det_limit_side_len,
+                        text_det_limit_type=text_det_limit_type,
+                        text_det_thresh=text_det_thresh,
+                        text_det_box_thresh=text_det_box_thresh,
+                        text_det_unclip_ratio=text_det_unclip_ratio,
+                        text_rec_score_thresh=text_rec_score_thresh,
+                    )
+                )
+                tbx1, tby1, tbx2, tby2 = (
+                    table_box[0],
+                    table_box[1],
+                    table_box[2],
+                    table_box[3],
+                )
+                if table_angle == "90":
+                    new_x1, new_y1 = tby1, img_width - tbx2
+                    new_x2, new_y2 = tby2, img_width - tbx1
+                elif table_angle == "180":
+                    new_x1, new_y1 = img_width - tbx2, img_height - tby2
+                    new_x2, new_y2 = img_width - tbx1, img_height - tby1
+                elif table_angle == "270":
+                    new_x1, new_y1 = img_height - tby2, tbx1
+                    new_x2, new_y2 = img_height - tby1, tbx2
+                table_box = [new_x1, new_y1, new_x2, new_y2]
+            layout_det_res = {}
             single_table_rec_res = self.predict_single_table_recognition_res(
                 doc_preprocessor_image,
                 overall_ocr_res,
                 table_box,
-                use_table_cells_ocr_results,
                 use_e2e_wired_table_rec_model,
                 use_e2e_wireless_table_rec_model,
+                use_wired_table_cells_trans_to_html,
+                use_wireless_table_cells_trans_to_html,
+                use_ocr_results_with_table_cells,
                 flag_find_nei_text=False,
             )
             single_table_rec_res["table_region_id"] = table_region_id
+            if use_table_orientation_classify == True and table_angle != "0":
+                img_height, img_width = doc_preprocessor_image.shape[:2]
+                single_table_rec_res["cell_box_list"] = (
+                    self.map_cells_to_original_image(
+                        single_table_rec_res["cell_box_list"],
+                        table_angle,
+                        img_width,
+                        img_height,
+                    )
+                )
             table_res_list.append(single_table_rec_res)
             table_region_id += 1
         else:
             if model_settings["use_layout_detection"]:
                 layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
-
+            img_height, img_width = doc_preprocessor_image.shape[:2]
             for box_info in layout_det_res["boxes"]:
                 if box_info["label"].lower() in ["table"]:
-                    crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                    crop_img_info = self._crop_by_boxes(
+                        doc_preprocessor_image, [box_info]
+                    )
                     crop_img_info = crop_img_info[0]
                     table_box = crop_img_info["box"]
+                    if use_table_orientation_classify == True:
+                        doc_preprocessor_image_copy = doc_preprocessor_image.copy()
+                        table_angle = next(
+                            self.table_orientation_classify_model(
+                                crop_img_info["img"]
+                            )
+                        )["label_names"][0]
+                    if table_angle == "90":
+                        crop_img_info["img"] = np.rot90(crop_img_info["img"], k=1)
+                        doc_preprocessor_image_copy = np.rot90(
+                            doc_preprocessor_image_copy, k=1
+                        )
+                    elif table_angle == "180":
+                        crop_img_info["img"] = np.rot90(crop_img_info["img"], k=2)
+                        doc_preprocessor_image_copy = np.rot90(
+                            doc_preprocessor_image_copy, k=2
+                        )
+                    elif table_angle == "270":
+                        crop_img_info["img"] = np.rot90(crop_img_info["img"], k=3)
+                        doc_preprocessor_image_copy = np.rot90(
+                            doc_preprocessor_image_copy, k=3
+                        )
+                    if table_angle in ["90", "180", "270"]:
+                        overall_ocr_res = next(
+                            self.general_ocr_pipeline(
+                                doc_preprocessor_image_copy,
+                                text_det_limit_side_len=text_det_limit_side_len,
+                                text_det_limit_type=text_det_limit_type,
+                                text_det_thresh=text_det_thresh,
+                                text_det_box_thresh=text_det_box_thresh,
+                                text_det_unclip_ratio=text_det_unclip_ratio,
+                                text_rec_score_thresh=text_rec_score_thresh,
+                            )
+                        )
+                        tbx1, tby1, tbx2, tby2 = (
+                            table_box[0],
+                            table_box[1],
+                            table_box[2],
+                            table_box[3],
+                        )
+                        if table_angle == "90":
+                            new_x1, new_y1 = tby1, img_width - tbx2
+                            new_x2, new_y2 = tby2, img_width - tbx1
+                        elif table_angle == "180":
+                            new_x1, new_y1 = img_width - tbx2, img_height - tby2
+                            new_x2, new_y2 = img_width - tbx1, img_height - tby1
+                        elif table_angle == "270":
+                            new_x1, new_y1 = img_height - tby2, tbx1
+                            new_x2, new_y2 = img_height - tby1, tbx2
+                        table_box = [new_x1, new_y1, new_x2, new_y2]
                     single_table_rec_res = (
                         self.predict_single_table_recognition_res(
                             crop_img_info["img"],
                             overall_ocr_res,
                             table_box,
-                            use_table_cells_ocr_results,
                             use_e2e_wired_table_rec_model,
                             use_e2e_wireless_table_rec_model,
+                            use_wired_table_cells_trans_to_html,
+                            use_wireless_table_cells_trans_to_html,
+                            use_ocr_results_with_table_cells,
                         )
                     )
                     single_table_rec_res["table_region_id"] = table_region_id
+                    if (
+                        use_table_orientation_classify == True
+                        and table_angle != "0"
+                    ):
+                        img_height_copy, img_width_copy = (
+                            doc_preprocessor_image_copy.shape[:2]
+                        )
+                        single_table_rec_res["cell_box_list"] = (
+                            self.map_cells_to_original_image(
+                                single_table_rec_res["cell_box_list"],
+                                table_angle,
+                                img_width_copy,
+                                img_height_copy,
+                            )
+                        )
                     table_res_list.append(single_table_rec_res)
                     table_region_id += 1

@@ -821,4 +1379,17 @@
             "table_res_list": table_res_list,
             "model_settings": model_settings,
         }
+
         yield TableRecognitionResult(single_img_res)
+
+
+@pipeline_requires_extra("ocr")
+class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
+    entities = ["table_recognition_v2"]
+
+    @property
+    def _pipeline_cls(self):
+        return _TableRecognitionPipelineV2
+
+    def _get_batch_size(self, config):
+        return 1
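
For orientation, a minimal usage sketch of the reworked table_recognition_v2 pipeline follows, exercising the predict() options introduced in this diff (use_table_orientation_classify, use_ocr_results_with_table_cells, and the *_cells_trans_to_html flags). This is an illustration rather than part of the diff: create_pipeline and the result-saving helpers are the public PaddleX API, while the input path and output directory are placeholders.

    from paddlex import create_pipeline

    # Sketch only: the keyword arguments mirror the new predict() signature above;
    # "table.jpg" and "./output/" are placeholder paths.
    pipeline = create_pipeline(pipeline="table_recognition_v2")

    output = pipeline.predict(
        "table.jpg",
        use_table_orientation_classify=True,          # rotate 90/180/270 tables upright before recognition
        use_ocr_results_with_table_cells=True,        # re-split OCR boxes that span multiple detected cells
        use_wired_table_cells_trans_to_html=False,    # keep the structure model instead of cell-to-HTML conversion
        use_wireless_table_cells_trans_to_html=False,
    )
    for res in output:
        res.print()
        res.save_to_img("./output/")
        res.save_to_html("./output/")

Note that use_table_orientation_classify defaults to True, so 3.0.1 lazily creates the new TableOrientationClassify sub-module (read from the SubModules section of the pipeline config, see paddlex/configs/pipelines/table_recognition_v2.yaml in the file list) the first time predict() runs, unless the flag is switched off.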