paddlex-3.0.1-py3-none-any.whl → paddlex-3.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. paddlex/.version +1 -1
  2. paddlex/inference/models/base/predictor/base_predictor.py +2 -0
  3. paddlex/inference/models/common/static_infer.py +20 -14
  4. paddlex/inference/models/common/ts/funcs.py +19 -8
  5. paddlex/inference/models/formula_recognition/predictor.py +1 -1
  6. paddlex/inference/models/formula_recognition/processors.py +2 -2
  7. paddlex/inference/models/text_recognition/result.py +1 -1
  8. paddlex/inference/pipelines/layout_parsing/layout_objects.py +859 -0
  9. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +144 -205
  10. paddlex/inference/pipelines/layout_parsing/result_v2.py +13 -272
  11. paddlex/inference/pipelines/layout_parsing/setting.py +1 -0
  12. paddlex/inference/pipelines/layout_parsing/utils.py +108 -312
  13. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +302 -247
  14. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +156 -104
  15. paddlex/inference/pipelines/ocr/result.py +2 -2
  16. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +1 -1
  17. paddlex/inference/serving/basic_serving/_app.py +47 -13
  18. paddlex/inference/serving/infra/utils.py +22 -17
  19. paddlex/inference/utils/hpi.py +60 -25
  20. paddlex/inference/utils/hpi_model_info_collection.json +627 -204
  21. paddlex/inference/utils/misc.py +20 -0
  22. paddlex/inference/utils/mkldnn_blocklist.py +36 -2
  23. paddlex/inference/utils/official_models.py +126 -5
  24. paddlex/inference/utils/pp_option.py +81 -21
  25. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +12 -2
  26. paddlex/ops/__init__.py +6 -3
  27. paddlex/utils/deps.py +2 -2
  28. paddlex/utils/device.py +4 -19
  29. paddlex/utils/download.py +10 -7
  30. paddlex/utils/flags.py +9 -0
  31. paddlex/utils/subclass_register.py +2 -2
  32. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/METADATA +307 -162
  33. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/RECORD +37 -35
  34. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/WHEEL +1 -1
  35. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/entry_points.txt +1 -0
  36. {paddlex-3.0.1.dist-info/licenses → paddlex-3.0.3.dist-info}/LICENSE +0 -0
  37. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/top_level.txt +0 -0
File: paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py
@@ -17,11 +17,14 @@ from typing import Dict, List, Tuple
17
17
 
18
18
  import numpy as np
19
19
 
20
- from ..result_v2 import LayoutParsingBlock, LayoutParsingRegion
21
- from ..setting import BLOCK_LABEL_MAP
20
+ from ..layout_objects import LayoutBlock, LayoutRegion
21
+ from ..setting import BLOCK_LABEL_MAP, XYCUT_SETTINGS
22
22
  from ..utils import calculate_overlap_ratio, calculate_projection_overlap_ratio
23
23
  from .utils import (
24
24
  calculate_discontinuous_projection,
25
+ euclidean_insert,
26
+ find_local_minima_flat_regions,
27
+ get_blocks_by_direction_interval,
25
28
  get_cut_blocks,
26
29
  insert_child_blocks,
27
30
  manhattan_insert,
@@ -31,16 +34,16 @@ from .utils import (
31
34
  reference_insert,
32
35
  shrink_overlapping_boxes,
33
36
  sort_normal_blocks,
34
- split_projection_profile,
35
37
  update_doc_title_child_blocks,
36
38
  update_paragraph_title_child_blocks,
39
+ update_region_child_blocks,
37
40
  update_vision_child_blocks,
38
41
  weighted_distance_insert,
39
42
  )
40
43
 
41
44
 
42
45
  def pre_process(
43
- region: LayoutParsingRegion,
46
+ region: LayoutRegion,
44
47
  ) -> List:
45
48
  """
46
49
  Preprocess the layout for sorting purposes.
@@ -63,10 +66,11 @@ def pre_process(
63
66
  "sub_paragraph_title",
64
67
  "doc_title_text",
65
68
  "vision_title",
69
+ "sub_region",
66
70
  ]
67
71
  pre_cut_block_idxes = []
68
72
  block_map = region.block_map
69
- blocks: List[LayoutParsingBlock] = list(block_map.values())
73
+ blocks: List[LayoutBlock] = list(block_map.values())
70
74
  for block in blocks:
71
75
  if block.order_label not in mask_labels:
72
76
  update_region_label(block, region)
@@ -83,7 +87,6 @@ def pre_process(
83
87
  ) / 2
84
88
  center_offset = abs(block_center - region.direction_center_coordinate)
85
89
  is_centered = center_offset <= tolerance_len
86
-
87
90
  if is_centered:
88
91
  pre_cut_block_idxes.append(block.index)
89
92
 
@@ -121,60 +124,83 @@ def pre_process(
121
124
  block.secondary_direction_start_coordinate
122
125
  )
123
126
  cut_coordinates.append(block.secondary_direction_end_coordinate)
124
- secondary_discontinuous = calculate_discontinuous_projection(
125
- all_boxes, direction=region.direction
127
+ secondary_check_bboxes = np.array(
128
+ [
129
+ block.bbox
130
+ for block in blocks
131
+ if block.order_label not in mask_labels + ["vision"]
132
+ ]
126
133
  )
127
- if len(secondary_discontinuous) == 1:
128
- if not discontinuous:
129
- discontinuous = calculate_discontinuous_projection(
130
- all_boxes, direction=cut_direction
131
- )
132
- current_interval = discontinuous[0]
133
- for interval in discontinuous[1:]:
134
- gap_len = interval[0] - current_interval[1]
135
- if gap_len >= region.text_line_height * 3:
136
- cut_coordinates.append(current_interval[1])
137
- elif gap_len > region.text_line_height * 1.2:
138
- (pre_blocks, post_blocks) = get_cut_blocks(
139
- list(block_map.values()), cut_direction, [current_interval[1]], []
140
- )
141
- pre_bboxes = np.array([block.bbox for block in pre_blocks])
142
- post_bboxes = np.array([block.bbox for block in post_blocks])
143
- projection_index = 1 if cut_direction == "horizontal" else 0
144
- pre_projection = projection_by_bboxes(pre_bboxes, projection_index)
145
- post_projection = projection_by_bboxes(post_bboxes, projection_index)
146
- pre_projection_min = np.min(pre_projection)
147
- post_projection_min = np.min(post_projection)
148
- pre_projection_min += 5 if pre_projection_min != 0 else 0
149
- post_projection_min += 5 if post_projection_min != 0 else 0
150
- pre_intervals = split_projection_profile(
151
- pre_projection, pre_projection_min, 1
152
- )
153
- post_intervals = split_projection_profile(
154
- post_projection, post_projection_min, 1
134
+ if len(secondary_check_bboxes) > 0 or blocks[0].label == "region":
135
+ secondary_discontinuous = calculate_discontinuous_projection(
136
+ secondary_check_bboxes, direction=region.direction
137
+ )
138
+ if len(secondary_discontinuous) == 1 or blocks[0].label == "region":
139
+ if not discontinuous:
140
+ discontinuous = calculate_discontinuous_projection(
141
+ all_boxes, direction=cut_direction
155
142
  )
156
- pre_gap_boxes = []
157
- if pre_intervals is not None:
158
- for start, end in zip(*pre_intervals):
159
- bbox = [0] * 4
160
- bbox[projection_index] = start
161
- bbox[projection_index + 2] = end
162
- pre_gap_boxes.append(bbox)
163
- post_gap_boxes = []
164
- if post_intervals is not None:
165
- for start, end in zip(*post_intervals):
166
- bbox = [0] * 4
167
- bbox[projection_index] = start
168
- bbox[projection_index + 2] = end
169
- post_gap_boxes.append(bbox)
170
- max_gap_boxes_num = max(len(pre_gap_boxes), len(post_gap_boxes))
171
- if max_gap_boxes_num > 0:
172
- discontinuous_intervals = calculate_discontinuous_projection(
173
- pre_gap_boxes + post_gap_boxes, direction=region.direction
143
+ current_interval = discontinuous[0]
144
+ pre_cut_coordinates = [
145
+ cood for cood in cut_coordinates if cood < current_interval[1]
146
+ ]
147
+ if not pre_cut_coordinates:
148
+ pre_cut_coordinate = 0
149
+ else:
150
+ pre_cut_coordinate = max(pre_cut_coordinates)
151
+ pre_cut_coordinate = max(current_interval[0], pre_cut_coordinate)
152
+ for interval in discontinuous[1:]:
153
+ gap_len = interval[0] - current_interval[1]
154
+ if (
155
+ gap_len >= region.text_line_height * 3
156
+ or blocks[0].label == "region"
157
+ ):
158
+ cut_coordinates.append(current_interval[1])
159
+ elif gap_len > region.text_line_height * 1.2:
160
+ pre_blocks = get_blocks_by_direction_interval(
161
+ list(block_map.values()),
162
+ pre_cut_coordinate,
163
+ current_interval[1],
164
+ cut_direction,
165
+ )
166
+ post_blocks = get_blocks_by_direction_interval(
167
+ list(block_map.values()),
168
+ current_interval[1],
169
+ interval[1],
170
+ cut_direction,
174
171
  )
175
- if len(discontinuous_intervals) != max_gap_boxes_num:
176
- cut_coordinates.append(current_interval[1])
177
- current_interval = interval
172
+ pre_bboxes = np.array([block.bbox for block in pre_blocks])
173
+ post_bboxes = np.array([block.bbox for block in post_blocks])
174
+ projection_index = 1 if cut_direction == "horizontal" else 0
175
+ pre_projection = projection_by_bboxes(pre_bboxes, projection_index)
176
+ post_projection = projection_by_bboxes(
177
+ post_bboxes, projection_index
178
+ )
179
+ pre_intervals = find_local_minima_flat_regions(pre_projection)
180
+ post_intervals = find_local_minima_flat_regions(post_projection)
181
+ pre_gap_boxes = []
182
+ if pre_intervals is not None:
183
+ for start, end in pre_intervals:
184
+ bbox = [0] * 4
185
+ bbox[projection_index] = start
186
+ bbox[projection_index + 2] = end
187
+ pre_gap_boxes.append(bbox)
188
+ post_gap_boxes = []
189
+ if post_intervals is not None:
190
+ for start, end in post_intervals:
191
+ bbox = [0] * 4
192
+ bbox[projection_index] = start
193
+ bbox[projection_index + 2] = end
194
+ post_gap_boxes.append(bbox)
195
+ max_gap_boxes_num = max(len(pre_gap_boxes), len(post_gap_boxes))
196
+ if max_gap_boxes_num > 0:
197
+ discontinuous_intervals = calculate_discontinuous_projection(
198
+ pre_gap_boxes + post_gap_boxes, direction=region.direction
199
+ )
200
+ if len(discontinuous_intervals) != max_gap_boxes_num:
201
+ pre_cut_coordinate = current_interval[1]
202
+ cut_coordinates.append(current_interval[1])
203
+ current_interval = interval
178
204
  cut_list = get_cut_blocks(blocks, cut_direction, cut_coordinates, mask_labels)
179
205
  pre_cut_list.extend(cut_list)
180
206
  if region.direction == "vertical":
@@ -184,14 +210,14 @@ def pre_process(
184
210
 
185
211
 
186
212
  def update_region_label(
187
- block: LayoutParsingBlock,
188
- region: LayoutParsingRegion,
213
+ block: LayoutBlock,
214
+ region: LayoutRegion,
189
215
  ) -> None:
190
216
  """
191
217
  Update the region label of a block based on its label and match the block with its children.
192
218
 
193
219
  Args:
194
- blocks (List[LayoutParsingBlock]): The list of blocks to process.
220
+ blocks (List[LayoutBlock]): The list of blocks to process.
195
221
  config (Dict[str, Any]): The configuration dictionary containing the necessary information.
196
222
  block_idx (int): The index of the current block being processed.
197
223
 
@@ -210,17 +236,18 @@ def update_region_label(
210
236
  elif block.label in BLOCK_LABEL_MAP["vision_labels"]:
211
237
  block.order_label = "vision"
212
238
  block.num_of_lines = 1
213
- block.direction = region.direction
214
- block.update_direction_info()
239
+ block.update_direction(region.direction)
215
240
  elif block.label in BLOCK_LABEL_MAP["footer_labels"]:
216
241
  block.order_label = "footer"
217
242
  elif block.label in BLOCK_LABEL_MAP["unordered_labels"]:
218
243
  block.order_label = "unordered"
244
+ elif block.label == "region":
245
+ block.order_label = "region"
219
246
  else:
220
247
  block.order_label = "normal_text"
221
248
 
222
249
  # only vision and doc title block can have child block
223
- if block.order_label not in ["vision", "doc_title", "paragraph_title"]:
250
+ if block.order_label not in ["vision", "doc_title", "paragraph_title", "region"]:
224
251
  return
225
252
 
226
253
  # match doc title text block
@@ -232,10 +259,12 @@ def update_region_label(
232
259
  # match vision title block and vision footnote block
233
260
  elif block.order_label == "vision":
234
261
  update_vision_child_blocks(block, region)
262
+ elif block.order_label == "region":
263
+ update_region_child_blocks(block, region)
235
264
 
236
265
 
237
266
  def get_layout_structure(
238
- blocks: List[LayoutParsingBlock],
267
+ blocks: List[LayoutBlock],
239
268
  region_direction: str,
240
269
  region_secondary_direction: str,
241
270
  ) -> Tuple[List[Dict[str, any]], bool]:
@@ -263,11 +292,11 @@ def get_layout_structure(
263
292
  continue
264
293
 
265
294
  bbox_iou = calculate_overlap_ratio(block.bbox, ref_block.bbox)
266
- if bbox_iou > 0:
295
+ if bbox_iou:
267
296
  if ref_block.order_label == "vision":
268
297
  ref_block.order_label = "cross_layout"
269
298
  break
270
- if block.order_label == "vision" or block.area < ref_block.area:
299
+ if bbox_iou > 0.1 and block.area < ref_block.area:
271
300
  block.order_label = "cross_layout"
272
301
  break
273
302
 
@@ -320,13 +349,19 @@ def get_layout_structure(
320
349
  and ref_match_projection_iou == 0
321
350
  and secondary_direction_ref_match_projection_overlap_ratio > 0
322
351
  ):
323
- if block.order_label == "vision" or (
352
+ if block.order_label in ["vision", "region"] or (
324
353
  ref_block.order_label == "normal_text"
325
354
  and second_ref_block.order_label == "normal_text"
326
- and ref_block.text_line_width
327
- > ref_block.text_line_height * 5
328
- and second_ref_block.text_line_width
329
- > second_ref_block.text_line_height * 5
355
+ and ref_block.long_side_length
356
+ > ref_block.text_line_height
357
+ * XYCUT_SETTINGS.get(
358
+ "cross_layout_ref_text_block_words_num_threshold", 8
359
+ )
360
+ and second_ref_block.long_side_length
361
+ > second_ref_block.text_line_height
362
+ * XYCUT_SETTINGS.get(
363
+ "cross_layout_ref_text_block_words_num_threshold", 8
364
+ )
330
365
  ):
331
366
  block.order_label = (
332
367
  "cross_reference"
@@ -374,20 +409,20 @@ def sort_by_xycut(
374
409
 
375
410
 
376
411
  def match_unsorted_blocks(
377
- sorted_blocks: List[LayoutParsingBlock],
378
- unsorted_blocks: List[LayoutParsingBlock],
379
- region: LayoutParsingRegion,
380
- ) -> List[LayoutParsingBlock]:
412
+ sorted_blocks: List[LayoutBlock],
413
+ unsorted_blocks: List[LayoutBlock],
414
+ region: LayoutRegion,
415
+ ) -> List[LayoutBlock]:
381
416
  """
382
417
  Match special blocks with the sorted blocks based on their region labels.
383
418
  Args:
384
- sorted_blocks (List[LayoutParsingBlock]): Sorted blocks to be matched.
385
- unsorted_blocks (List[LayoutParsingBlock]): Unsorted blocks to be matched.
419
+ sorted_blocks (List[LayoutBlock]): Sorted blocks to be matched.
420
+ unsorted_blocks (List[LayoutBlock]): Unsorted blocks to be matched.
386
421
  config (Dict): Configuration dictionary containing various parameters.
387
422
  median_width (int): Median width value used for calculations.
388
423
 
389
424
  Returns:
390
- List[LayoutParsingBlock]: The updated sorted blocks after matching special blocks.
425
+ List[LayoutBlock]: The updated sorted blocks after matching special blocks.
391
426
  """
392
427
  distance_type_map = {
393
428
  "cross_layout": weighted_distance_insert,
@@ -398,6 +433,7 @@ def match_unsorted_blocks(
398
433
  "cross_reference": reference_insert,
399
434
  "unordered": manhattan_insert,
400
435
  "other": manhattan_insert,
436
+ "region": euclidean_insert,
401
437
  }
402
438
 
403
439
  unsorted_blocks = sort_normal_blocks(
@@ -407,17 +443,19 @@ def match_unsorted_blocks(
407
443
  region.direction,
408
444
  )
409
445
  for idx, block in enumerate(unsorted_blocks):
410
- order_label = block.order_label
446
+ order_label = block.order_label if block.label != "region" else "region"
411
447
  if idx == 0 and order_label == "doc_title":
412
448
  sorted_blocks.insert(0, block)
413
449
  continue
414
- sorted_blocks = distance_type_map[order_label](block, sorted_blocks, region)
450
+ sorted_blocks = distance_type_map[order_label](
451
+ block=block, sorted_blocks=sorted_blocks, region=region
452
+ )
415
453
  return sorted_blocks
416
454
 
417
455
 
418
456
  def xycut_enhanced(
419
- region: LayoutParsingRegion,
420
- ) -> LayoutParsingRegion:
457
+ region: LayoutRegion,
458
+ ) -> LayoutRegion:
421
459
  """
422
460
  xycut_enhance function performs the following steps:
423
461
  1. Preprocess the input blocks by extracting headers, footers, and pre-cut blocks.
@@ -428,34 +466,34 @@ def xycut_enhanced(
428
466
  6. Return the ordered result list.
429
467
 
430
468
  Args:
431
- blocks (List[LayoutParsingBlock]): Input blocks to be processed.
469
+ blocks (List[LayoutBlock]): Input blocks to be processed.
432
470
 
433
471
  Returns:
434
- List[LayoutParsingBlock]: Ordered result list after processing.
472
+ List[LayoutBlock]: Ordered result list after processing.
435
473
  """
436
474
  if len(region.block_map) == 0:
437
475
  return []
438
476
 
439
- pre_cut_list: List[List[LayoutParsingBlock]] = pre_process(region)
440
- final_order_res_list: List[LayoutParsingBlock] = []
477
+ pre_cut_list: List[List[LayoutBlock]] = pre_process(region)
478
+ final_order_res_list: List[LayoutBlock] = []
441
479
 
442
- header_blocks: List[LayoutParsingBlock] = [
480
+ header_blocks: List[LayoutBlock] = [
443
481
  region.block_map[idx] for idx in region.header_block_idxes
444
482
  ]
445
- unordered_blocks: List[LayoutParsingBlock] = [
483
+ unordered_blocks: List[LayoutBlock] = [
446
484
  region.block_map[idx] for idx in region.unordered_block_idxes
447
485
  ]
448
- footer_blocks: List[LayoutParsingBlock] = [
486
+ footer_blocks: List[LayoutBlock] = [
449
487
  region.block_map[idx] for idx in region.footer_block_idxes
450
488
  ]
451
489
 
452
- header_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
490
+ header_blocks: List[LayoutBlock] = sort_normal_blocks(
453
491
  header_blocks, region.text_line_height, region.text_line_width, region.direction
454
492
  )
455
- footer_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
493
+ footer_blocks: List[LayoutBlock] = sort_normal_blocks(
456
494
  footer_blocks, region.text_line_height, region.text_line_width, region.direction
457
495
  )
458
- unordered_blocks: List[LayoutParsingBlock] = sort_normal_blocks(
496
+ unordered_blocks: List[LayoutBlock] = sort_normal_blocks(
459
497
  unordered_blocks,
460
498
  region.text_line_height,
461
499
  region.text_line_width,
@@ -463,16 +501,26 @@ def xycut_enhanced(
463
501
  )
464
502
  final_order_res_list.extend(header_blocks)
465
503
 
466
- unsorted_blocks: List[LayoutParsingBlock] = []
467
- sorted_blocks_by_pre_cuts: List[LayoutParsingBlock] = []
504
+ unsorted_blocks: List[LayoutBlock] = []
505
+ sorted_blocks_by_pre_cuts: List[LayoutBlock] = []
468
506
  for pre_cut_blocks in pre_cut_list:
469
- sorted_blocks: List[LayoutParsingBlock] = []
470
- doc_title_blocks: List[LayoutParsingBlock] = []
471
- xy_cut_blocks: List[LayoutParsingBlock] = []
507
+ sorted_blocks: List[LayoutBlock] = []
508
+ doc_title_blocks: List[LayoutBlock] = []
509
+ xy_cut_blocks: List[LayoutBlock] = []
472
510
 
473
- get_layout_structure(
474
- pre_cut_blocks, region.direction, region.secondary_direction
475
- )
511
+ if pre_cut_blocks and pre_cut_blocks[0].label == "region":
512
+ block_bboxes = np.array([block.bbox for block in pre_cut_blocks])
513
+ discontinuous = calculate_discontinuous_projection(
514
+ block_bboxes, direction=region.direction
515
+ )
516
+ if len(discontinuous) == 1:
517
+ get_layout_structure(
518
+ pre_cut_blocks, region.direction, region.secondary_direction
519
+ )
520
+ else:
521
+ get_layout_structure(
522
+ pre_cut_blocks, region.direction, region.secondary_direction
523
+ )
476
524
 
477
525
  # Get xy cut blocks and add other blocks in special_block_map
478
526
  for block in pre_cut_blocks:
@@ -494,8 +542,6 @@ def xycut_enhanced(
494
542
  discontinuous = calculate_discontinuous_projection(
495
543
  block_bboxes, direction=region.direction
496
544
  )
497
- if len(discontinuous) > 1:
498
- xy_cut_blocks = [block for block in xy_cut_blocks]
499
545
  blocks_to_sort = deepcopy(xy_cut_blocks)
500
546
  if region.direction == "vertical":
501
547
  for block in blocks_to_sort:
@@ -526,7 +572,7 @@ def xycut_enhanced(
526
572
  )
527
573
  )
528
574
  blocks_to_sort = shrink_overlapping_boxes(
529
- blocks_to_sort, region.direction
575
+ blocks_to_sort, region.secondary_direction
530
576
  )
531
577
  block_bboxes = np.array([block.bbox for block in blocks_to_sort])
532
578
  sorted_indexes = sort_by_xycut(
@@ -536,13 +582,19 @@ def xycut_enhanced(
536
582
  sorted_blocks = [
537
583
  region.block_map[blocks_to_sort[i].index] for i in sorted_indexes
538
584
  ]
539
-
540
585
  sorted_blocks = match_unsorted_blocks(
541
586
  sorted_blocks,
542
587
  doc_title_blocks,
543
588
  region=region,
544
589
  )
545
590
 
591
+ if unsorted_blocks and unsorted_blocks[0].label == "region":
592
+ sorted_blocks = match_unsorted_blocks(
593
+ sorted_blocks,
594
+ unsorted_blocks,
595
+ region=region,
596
+ )
597
+ unsorted_blocks = []
546
598
  sorted_blocks_by_pre_cuts.extend(sorted_blocks)
547
599
 
548
600
  final_sorted_blocks = match_unsorted_blocks(
File: paddlex/inference/pipelines/ocr/result.py
@@ -206,10 +206,10 @@ def draw_box_txt_fine(
206
206
  np.ndarray: An image with the text drawn in the specified box.
207
207
  """
208
208
  box_height = int(
209
- math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
209
+ math.sqrt(float(box[0][0] - box[3][0]) ** 2 + float(box[0][1] - box[3][1]) ** 2)
210
210
  )
211
211
  box_width = int(
212
- math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
212
+ math.sqrt(float(box[0][0] - box[1][0]) ** 2 + float(box[0][1] - box[1][1]) ** 2)
213
213
  )
214
214
 
215
215
  if box_height > 2 * box_width and box_height > 30:
File: paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py
@@ -638,7 +638,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
638
638
 
639
639
  for image_array in self.img_reader([input]):
640
640
 
641
- image_string = cv2.imencode(".jpg", image_array)[1].tostring()
641
+ image_string = cv2.imencode(".jpg", image_array)[1].tobytes()
642
642
  image_base64 = base64.b64encode(image_string).decode("utf-8")
643
643
  result = {}
644
644
  for key in key_list:
File: paddlex/inference/serving/basic_serving/_app.py
@@ -15,6 +15,8 @@
15
15
  import asyncio
16
16
  import contextlib
17
17
  import json
18
+ from queue import Queue
19
+ from threading import Thread
18
20
  from typing import (
19
21
  Any,
20
22
  AsyncGenerator,
@@ -74,16 +76,22 @@ class PipelineWrapper(Generic[PipelineT]):
74
76
  def __init__(self, pipeline: PipelineT) -> None:
75
77
  super().__init__()
76
78
  self._pipeline = pipeline
77
- self._lock = asyncio.Lock()
79
+ # HACK: We work around a bug in Paddle Inference by performing all
80
+ # inference in the same thread.
81
+ self._queue = Queue()
82
+ self._closed = False
83
+ self._loop = asyncio.get_running_loop()
84
+ self._thread = Thread(target=self._worker, daemon=False)
85
+ self._thread.start()
78
86
 
79
87
  @property
80
88
  def pipeline(self) -> PipelineT:
81
89
  return self._pipeline
82
90
 
83
91
  async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
84
- def _infer() -> List[Any]:
92
+ def _infer(*args, **kwargs) -> List[Any]:
85
93
  output: list = []
86
- with contextlib.closing(self._pipeline(*args, **kwargs)) as it:
94
+ with contextlib.closing(self._pipeline.predict(*args, **kwargs)) as it:
87
95
  for item in it:
88
96
  if _is_error(item):
89
97
  raise fastapi.HTTPException(
@@ -93,11 +101,34 @@ class PipelineWrapper(Generic[PipelineT]):
93
101
 
94
102
  return output
95
103
 
96
- return await self.call(_infer)
104
+ return await self.call(_infer, *args, **kwargs)
97
105
 
98
106
  async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
99
- async with self._lock:
100
- return await call_async(func, *args, **kwargs)
107
+ if self._closed:
108
+ raise RuntimeError("`PipelineWrapper` has already been closed")
109
+ fut = self._loop.create_future()
110
+ self._queue.put((func, args, kwargs, fut))
111
+ return await fut
112
+
113
+ async def close(self):
114
+ if not self._closed:
115
+ self._queue.put(None)
116
+ await call_async(self._thread.join)
117
+ self._closed = True
118
+
119
+ def _worker(self):
120
+ while not self._closed:
121
+ item = self._queue.get()
122
+ if item is None:
123
+ break
124
+ func, args, kwargs, fut = item
125
+ try:
126
+ result = func(*args, **kwargs)
127
+ self._loop.call_soon_threadsafe(fut.set_result, result)
128
+ except Exception as e:
129
+ self._loop.call_soon_threadsafe(fut.set_exception, e)
130
+ finally:
131
+ self._queue.task_done()
101
132
 
102
133
 
103
134
  @class_requires_deps("aiohttp")
@@ -141,14 +172,17 @@ def create_app(
141
172
  @contextlib.asynccontextmanager
142
173
  async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
143
174
  ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
144
- if app_aiohttp_session:
145
- async with aiohttp.ClientSession(
146
- cookie_jar=aiohttp.DummyCookieJar()
147
- ) as aiohttp_session:
148
- ctx.aiohttp_session = aiohttp_session
175
+ try:
176
+ if app_aiohttp_session:
177
+ async with aiohttp.ClientSession(
178
+ cookie_jar=aiohttp.DummyCookieJar()
179
+ ) as aiohttp_session:
180
+ ctx.aiohttp_session = aiohttp_session
181
+ yield
182
+ else:
149
183
  yield
150
- else:
151
- yield
184
+ finally:
185
+ await ctx.pipeline.close()
152
186
 
153
187
  # Should we control API versions?
154
188
  app = fastapi.FastAPI(lifespan=_app_lifespan)
File: paddlex/inference/serving/infra/utils.py
@@ -18,6 +18,7 @@ import io
18
18
  import mimetypes
19
19
  import re
20
20
  import tempfile
21
+ import threading
21
22
  import uuid
22
23
  from functools import partial
23
24
  from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar, Union, overload
@@ -176,29 +177,33 @@ def base64_encode(data: bytes) -> str:
176
177
  return base64.b64encode(data).decode("ascii")
177
178
 
178
179
 
180
+ _lock = threading.Lock()
181
+
182
+
179
183
  @function_requires_deps("pypdfium2", "opencv-contrib-python")
180
184
  def read_pdf(
181
185
  bytes_: bytes, max_num_imgs: Optional[int] = None
182
186
  ) -> Tuple[List[np.ndarray], PDFInfo]:
183
187
  images: List[np.ndarray] = []
184
188
  page_info_list: List[PDFPageInfo] = []
185
- doc = pdfium.PdfDocument(bytes_)
186
- for page in doc:
187
- if max_num_imgs is not None and len(images) >= max_num_imgs:
188
- break
189
- # TODO: Do not always use zoom=2.0
190
- zoom = 2.0
191
- deg = 0
192
- image = page.render(scale=zoom, rotation=deg).to_pil()
193
- image = image.convert("RGB")
194
- image = np.array(image)
195
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
196
- images.append(image)
197
- page_info = PDFPageInfo(
198
- width=image.shape[1],
199
- height=image.shape[0],
200
- )
201
- page_info_list.append(page_info)
189
+ with _lock:
190
+ doc = pdfium.PdfDocument(bytes_)
191
+ for page in doc:
192
+ if max_num_imgs is not None and len(images) >= max_num_imgs:
193
+ break
194
+ # TODO: Do not always use zoom=2.0
195
+ zoom = 2.0
196
+ deg = 0
197
+ image = page.render(scale=zoom, rotation=deg).to_pil()
198
+ image = image.convert("RGB")
199
+ image = np.array(image)
200
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
201
+ images.append(image)
202
+ page_info = PDFPageInfo(
203
+ width=image.shape[1],
204
+ height=image.shape[0],
205
+ )
206
+ page_info_list.append(page_info)
202
207
  pdf_info = PDFInfo(
203
208
  numPages=len(page_info_list),
204
209
  pages=page_info_list,