paddlex 3.0.0rc1__py3-none-any.whl → 3.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -1
- paddlex/__init__.py +1 -1
- paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
- paddlex/configs/modules/textline_orientation/PP-LCNet_x1_0_textline_ori.yaml +41 -0
- paddlex/configs/pipelines/OCR.yaml +7 -6
- paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +3 -1
- paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +91 -34
- paddlex/configs/pipelines/PP-StructureV3.yaml +72 -72
- paddlex/configs/pipelines/doc_understanding.yaml +1 -1
- paddlex/configs/pipelines/formula_recognition.yaml +2 -2
- paddlex/configs/pipelines/layout_parsing.yaml +3 -2
- paddlex/configs/pipelines/seal_recognition.yaml +1 -0
- paddlex/configs/pipelines/table_recognition.yaml +2 -1
- paddlex/configs/pipelines/table_recognition_v2.yaml +7 -1
- paddlex/hpip_links.html +20 -20
- paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +33 -10
- paddlex/inference/common/batch_sampler/image_batch_sampler.py +34 -25
- paddlex/inference/common/result/mixin.py +19 -12
- paddlex/inference/models/base/predictor/base_predictor.py +2 -8
- paddlex/inference/models/common/static_infer.py +29 -73
- paddlex/inference/models/common/tokenizer/__init__.py +2 -0
- paddlex/inference/models/common/tokenizer/clip_tokenizer.py +1 -1
- paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +2 -2
- paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
- paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +7 -1
- paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils.py +13 -13
- paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3 -3
- paddlex/inference/models/common/tokenizer/vocab.py +7 -7
- paddlex/inference/models/common/ts/funcs.py +19 -8
- paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
- paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
- paddlex/inference/models/common/vlm/generation/configuration_utils.py +1 -1
- paddlex/inference/models/common/vlm/generation/logits_process.py +1 -1
- paddlex/inference/models/common/vlm/generation/utils.py +1 -1
- paddlex/inference/models/common/vlm/transformers/configuration_utils.py +3 -3
- paddlex/inference/models/common/vlm/transformers/conversion_utils.py +3 -3
- paddlex/inference/models/common/vlm/transformers/model_outputs.py +2 -2
- paddlex/inference/models/common/vlm/transformers/model_utils.py +7 -31
- paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
- paddlex/inference/models/doc_vlm/modeling/__init__.py +2 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +0 -105
- paddlex/inference/models/doc_vlm/predictor.py +79 -24
- paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
- paddlex/inference/models/doc_vlm/processors/__init__.py +2 -0
- paddlex/inference/models/doc_vlm/processors/common.py +189 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +21 -176
- paddlex/inference/models/formula_recognition/predictor.py +8 -2
- paddlex/inference/models/formula_recognition/processors.py +90 -77
- paddlex/inference/models/formula_recognition/result.py +28 -27
- paddlex/inference/models/image_feature/processors.py +3 -4
- paddlex/inference/models/keypoint_detection/predictor.py +3 -0
- paddlex/inference/models/object_detection/predictor.py +2 -0
- paddlex/inference/models/object_detection/processors.py +28 -3
- paddlex/inference/models/object_detection/utils.py +2 -0
- paddlex/inference/models/table_structure_recognition/result.py +0 -10
- paddlex/inference/models/text_detection/predictor.py +8 -0
- paddlex/inference/models/text_detection/processors.py +44 -10
- paddlex/inference/models/text_detection/result.py +0 -10
- paddlex/inference/models/text_recognition/result.py +1 -1
- paddlex/inference/pipelines/__init__.py +9 -5
- paddlex/inference/pipelines/_parallel.py +172 -0
- paddlex/inference/pipelines/anomaly_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/attribute_recognition/pipeline.py +11 -1
- paddlex/inference/pipelines/base.py +14 -4
- paddlex/inference/pipelines/components/faisser.py +1 -1
- paddlex/inference/pipelines/doc_preprocessor/pipeline.py +53 -27
- paddlex/inference/pipelines/formula_recognition/pipeline.py +120 -82
- paddlex/inference/pipelines/formula_recognition/result.py +1 -11
- paddlex/inference/pipelines/image_classification/pipeline.py +16 -6
- paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +16 -6
- paddlex/inference/pipelines/instance_segmentation/pipeline.py +16 -6
- paddlex/inference/pipelines/keypoint_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/layout_parsing/layout_objects.py +859 -0
- paddlex/inference/pipelines/layout_parsing/pipeline.py +34 -47
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +832 -260
- paddlex/inference/pipelines/layout_parsing/result.py +4 -17
- paddlex/inference/pipelines/layout_parsing/result_v2.py +259 -245
- paddlex/inference/pipelines/layout_parsing/setting.py +88 -0
- paddlex/inference/pipelines/layout_parsing/utils.py +391 -2028
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1199 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +615 -0
- paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +2 -2
- paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +2 -2
- paddlex/inference/pipelines/object_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/ocr/pipeline.py +127 -70
- paddlex/inference/pipelines/ocr/result.py +21 -18
- paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +2 -2
- paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +2 -2
- paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +2 -2
- paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +2 -5
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +6 -6
- paddlex/inference/pipelines/rotated_object_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/seal_recognition/pipeline.py +109 -53
- paddlex/inference/pipelines/semantic_segmentation/pipeline.py +16 -6
- paddlex/inference/pipelines/small_object_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/table_recognition/pipeline.py +26 -18
- paddlex/inference/pipelines/table_recognition/pipeline_v2.py +624 -53
- paddlex/inference/pipelines/table_recognition/result.py +1 -1
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +9 -5
- paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +2 -2
- paddlex/inference/pipelines/ts_classification/pipeline.py +2 -2
- paddlex/inference/pipelines/ts_forecasting/pipeline.py +2 -2
- paddlex/inference/pipelines/video_classification/pipeline.py +2 -2
- paddlex/inference/pipelines/video_detection/pipeline.py +2 -2
- paddlex/inference/serving/basic_serving/_app.py +46 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +5 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +0 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +0 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +1 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +6 -2
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +1 -5
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +4 -5
- paddlex/inference/serving/infra/utils.py +20 -22
- paddlex/inference/serving/schemas/formula_recognition.py +1 -1
- paddlex/inference/serving/schemas/layout_parsing.py +1 -2
- paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +1 -2
- paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +2 -2
- paddlex/inference/serving/schemas/pp_structurev3.py +10 -6
- paddlex/inference/serving/schemas/seal_recognition.py +1 -1
- paddlex/inference/serving/schemas/table_recognition.py +2 -6
- paddlex/inference/serving/schemas/table_recognition_v2.py +5 -6
- paddlex/inference/utils/hpi.py +30 -16
- paddlex/inference/utils/hpi_model_info_collection.json +666 -162
- paddlex/inference/utils/io/readers.py +12 -12
- paddlex/inference/utils/misc.py +20 -0
- paddlex/inference/utils/mkldnn_blocklist.py +59 -0
- paddlex/inference/utils/official_models.py +140 -5
- paddlex/inference/utils/pp_option.py +74 -9
- paddlex/model.py +2 -2
- paddlex/modules/__init__.py +1 -1
- paddlex/modules/anomaly_detection/evaluator.py +2 -2
- paddlex/modules/base/__init__.py +1 -1
- paddlex/modules/base/evaluator.py +5 -5
- paddlex/modules/base/trainer.py +1 -1
- paddlex/modules/doc_vlm/dataset_checker.py +2 -2
- paddlex/modules/doc_vlm/evaluator.py +2 -2
- paddlex/modules/doc_vlm/exportor.py +2 -2
- paddlex/modules/doc_vlm/model_list.py +1 -1
- paddlex/modules/doc_vlm/trainer.py +2 -2
- paddlex/modules/face_recognition/evaluator.py +2 -2
- paddlex/modules/formula_recognition/evaluator.py +5 -2
- paddlex/modules/formula_recognition/model_list.py +3 -0
- paddlex/modules/formula_recognition/trainer.py +3 -0
- paddlex/modules/general_recognition/evaluator.py +1 -1
- paddlex/modules/image_classification/evaluator.py +2 -2
- paddlex/modules/image_classification/model_list.py +1 -0
- paddlex/modules/instance_segmentation/evaluator.py +1 -1
- paddlex/modules/keypoint_detection/evaluator.py +1 -1
- paddlex/modules/m_3d_bev_detection/evaluator.py +2 -2
- paddlex/modules/multilabel_classification/evaluator.py +2 -2
- paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +4 -4
- paddlex/modules/object_detection/evaluator.py +2 -2
- paddlex/modules/object_detection/model_list.py +2 -0
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +12 -2
- paddlex/modules/semantic_segmentation/evaluator.py +2 -2
- paddlex/modules/table_recognition/evaluator.py +2 -2
- paddlex/modules/text_detection/evaluator.py +2 -2
- paddlex/modules/text_detection/model_list.py +2 -0
- paddlex/modules/text_recognition/evaluator.py +2 -2
- paddlex/modules/text_recognition/model_list.py +2 -0
- paddlex/modules/ts_anomaly_detection/evaluator.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/ts_classification/evaluator.py +2 -2
- paddlex/modules/ts_forecast/evaluator.py +2 -2
- paddlex/modules/video_classification/evaluator.py +2 -2
- paddlex/modules/video_detection/evaluator.py +2 -2
- paddlex/ops/__init__.py +8 -5
- paddlex/paddlex_cli.py +19 -13
- paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +2 -2
- paddlex/repo_apis/PaddleClas_api/cls/config.py +1 -1
- paddlex/repo_apis/PaddleClas_api/cls/model.py +1 -1
- paddlex/repo_apis/PaddleClas_api/cls/register.py +10 -0
- paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/config.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/model.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +25 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/register.py +30 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +5 -9
- paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +27 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/model.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/register.py +18 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +5 -9
- paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +18 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/seg/model.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +3 -3
- paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +4 -4
- paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_det/config.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_det/model.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +1 -1
- paddlex/repo_apis/base/config.py +1 -1
- paddlex/repo_manager/core.py +3 -3
- paddlex/repo_manager/meta.py +6 -2
- paddlex/repo_manager/repo.py +17 -16
- paddlex/utils/custom_device_list.py +26 -2
- paddlex/utils/deps.py +3 -3
- paddlex/utils/device.py +5 -13
- paddlex/utils/env.py +4 -0
- paddlex/utils/flags.py +11 -4
- paddlex/utils/fonts/__init__.py +34 -4
- paddlex/utils/misc.py +1 -1
- paddlex/utils/subclass_register.py +2 -2
- {paddlex-3.0.0rc1.dist-info → paddlex-3.0.2.dist-info}/METADATA +349 -208
- {paddlex-3.0.0rc1.dist-info → paddlex-3.0.2.dist-info}/RECORD +240 -211
- {paddlex-3.0.0rc1.dist-info → paddlex-3.0.2.dist-info}/WHEEL +1 -1
- {paddlex-3.0.0rc1.dist-info → paddlex-3.0.2.dist-info}/entry_points.txt +1 -0
- {paddlex-3.0.0rc1.dist-info/licenses → paddlex-3.0.2.dist-info}/LICENSE +0 -0
- {paddlex-3.0.0rc1.dist-info → paddlex-3.0.2.dist-info}/top_level.txt +0 -0
@@ -474,6 +474,8 @@ def restructured_boxes(
|
|
474
474
|
ymin = max(0, ymin)
|
475
475
|
xmax = min(w, xmax)
|
476
476
|
ymax = min(h, ymax)
|
477
|
+
if xmax <= xmin or ymax <= ymin:
|
478
|
+
continue
|
477
479
|
box_list.append(
|
478
480
|
{
|
479
481
|
"cls_id": int(box[0]),
|
@@ -744,11 +746,34 @@ class DetPostProcess:
|
|
744
746
|
)
|
745
747
|
|
746
748
|
if layout_nms:
|
747
|
-
pass
|
748
|
-
### Layout postprocess for NMS
|
749
749
|
selected_indices = nms(boxes, iou_same=0.6, iou_diff=0.98)
|
750
750
|
boxes = np.array(boxes[selected_indices])
|
751
751
|
|
752
|
+
filter_large_image = True
|
753
|
+
if filter_large_image and len(boxes) > 1 and boxes.shape[1] == 6:
|
754
|
+
if img_size[0] > img_size[1]:
|
755
|
+
area_thres = 0.82
|
756
|
+
else:
|
757
|
+
area_thres = 0.93
|
758
|
+
image_index = self.labels.index("image") if "image" in self.labels else None
|
759
|
+
img_area = img_size[0] * img_size[1]
|
760
|
+
filtered_boxes = []
|
761
|
+
for box in boxes:
|
762
|
+
label_index, score, xmin, ymin, xmax, ymax = box
|
763
|
+
if label_index == image_index:
|
764
|
+
xmin = max(0, xmin)
|
765
|
+
ymin = max(0, ymin)
|
766
|
+
xmax = min(img_size[0], xmax)
|
767
|
+
ymax = min(img_size[1], ymax)
|
768
|
+
box_area = (xmax - xmin) * (ymax - ymin)
|
769
|
+
if box_area <= area_thres * img_area:
|
770
|
+
filtered_boxes.append(box)
|
771
|
+
else:
|
772
|
+
filtered_boxes.append(box)
|
773
|
+
if len(filtered_boxes) == 0:
|
774
|
+
filtered_boxes = boxes
|
775
|
+
boxes = np.array(filtered_boxes)
|
776
|
+
|
752
777
|
if layout_merge_bboxes_mode:
|
753
778
|
formula_index = (
|
754
779
|
self.labels.index("formula") if "formula" in self.labels else None
|
@@ -798,7 +823,7 @@ class DetPostProcess:
|
|
798
823
|
boxes = boxes[keep_mask]
|
799
824
|
|
800
825
|
if boxes.size == 0:
|
801
|
-
return
|
826
|
+
return []
|
802
827
|
|
803
828
|
if layout_unclip_ratio:
|
804
829
|
if isinstance(layout_unclip_ratio, float):
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import copy
|
16
|
-
from pathlib import Path
|
17
16
|
|
18
17
|
import numpy as np
|
19
18
|
|
@@ -28,15 +27,6 @@ if is_dep_available("opencv-contrib-python"):
|
|
28
27
|
class TableRecResult(BaseCVResult):
|
29
28
|
"""SaveTableResults"""
|
30
29
|
|
31
|
-
def _get_input_fn(self):
|
32
|
-
fn = super()._get_input_fn()
|
33
|
-
if (page_idx := self["page_index"]) is not None:
|
34
|
-
fp = Path(fn)
|
35
|
-
stem, suffix = fp.stem, fp.suffix
|
36
|
-
return f"{stem}_{page_idx}{suffix}"
|
37
|
-
else:
|
38
|
-
return fn
|
39
|
-
|
40
30
|
def _to_img(self):
|
41
31
|
image = self["input_img"]
|
42
32
|
bbox_res = self["bbox"]
|
@@ -41,6 +41,7 @@ class TextDetPredictor(BasePredictor):
|
|
41
41
|
box_thresh: Union[float, None] = None,
|
42
42
|
unclip_ratio: Union[float, None] = None,
|
43
43
|
input_shape=None,
|
44
|
+
max_side_limit: int = 4000,
|
44
45
|
*args,
|
45
46
|
**kwargs
|
46
47
|
):
|
@@ -52,6 +53,7 @@ class TextDetPredictor(BasePredictor):
|
|
52
53
|
self.box_thresh = box_thresh
|
53
54
|
self.unclip_ratio = unclip_ratio
|
54
55
|
self.input_shape = input_shape
|
56
|
+
self.max_side_limit = max_side_limit
|
55
57
|
self.pre_tfs, self.infer, self.post_op = self._build()
|
56
58
|
|
57
59
|
def _build_batch_sampler(self):
|
@@ -85,6 +87,7 @@ class TextDetPredictor(BasePredictor):
|
|
85
87
|
thresh: Union[float, None] = None,
|
86
88
|
box_thresh: Union[float, None] = None,
|
87
89
|
unclip_ratio: Union[float, None] = None,
|
90
|
+
max_side_limit: Union[int, None] = None,
|
88
91
|
):
|
89
92
|
|
90
93
|
batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data.instances)
|
@@ -92,6 +95,9 @@ class TextDetPredictor(BasePredictor):
|
|
92
95
|
imgs=batch_raw_imgs,
|
93
96
|
limit_side_len=limit_side_len or self.limit_side_len,
|
94
97
|
limit_type=limit_type or self.limit_type,
|
98
|
+
max_side_limit=(
|
99
|
+
max_side_limit if max_side_limit is not None else self.max_side_limit
|
100
|
+
),
|
95
101
|
)
|
96
102
|
batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
|
97
103
|
batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
|
@@ -127,6 +133,8 @@ class TextDetPredictor(BasePredictor):
|
|
127
133
|
# TODO: align to PaddleOCR
|
128
134
|
|
129
135
|
if self.model_name in (
|
136
|
+
"PP-OCRv5_server_det",
|
137
|
+
"PP-OCRv5_mobile_det",
|
130
138
|
"PP-OCRv4_server_det",
|
131
139
|
"PP-OCRv4_mobile_det",
|
132
140
|
"PP-OCRv3_server_det",
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import math
|
16
|
-
import sys
|
17
16
|
from typing import Union
|
18
17
|
|
19
18
|
import numpy as np
|
@@ -33,8 +32,7 @@ if is_dep_available("pyclipper"):
|
|
33
32
|
class DetResizeForTest:
|
34
33
|
"""DetResizeForTest"""
|
35
34
|
|
36
|
-
def __init__(self, input_shape=None, **kwargs):
|
37
|
-
super().__init__()
|
35
|
+
def __init__(self, input_shape=None, max_side_limit=4000, **kwargs):
|
38
36
|
self.resize_type = 0
|
39
37
|
self.keep_ratio = False
|
40
38
|
if input_shape is not None:
|
@@ -55,22 +53,34 @@ class DetResizeForTest:
|
|
55
53
|
self.limit_side_len = 736
|
56
54
|
self.limit_type = "min"
|
57
55
|
|
56
|
+
self.max_side_limit = max_side_limit
|
57
|
+
|
58
58
|
def __call__(
|
59
59
|
self,
|
60
60
|
imgs,
|
61
61
|
limit_side_len: Union[int, None] = None,
|
62
62
|
limit_type: Union[str, None] = None,
|
63
|
+
max_side_limit: Union[int, None] = None,
|
63
64
|
):
|
64
65
|
"""apply"""
|
66
|
+
max_side_limit = (
|
67
|
+
max_side_limit if max_side_limit is not None else self.max_side_limit
|
68
|
+
)
|
65
69
|
resize_imgs, img_shapes = [], []
|
66
70
|
for ori_img in imgs:
|
67
|
-
img, shape = self.resize(
|
71
|
+
img, shape = self.resize(
|
72
|
+
ori_img, limit_side_len, limit_type, max_side_limit
|
73
|
+
)
|
68
74
|
resize_imgs.append(img)
|
69
75
|
img_shapes.append(shape)
|
70
76
|
return resize_imgs, img_shapes
|
71
77
|
|
72
78
|
def resize(
|
73
|
-
self,
|
79
|
+
self,
|
80
|
+
img,
|
81
|
+
limit_side_len: Union[int, None],
|
82
|
+
limit_type: Union[str, None],
|
83
|
+
max_side_limit: Union[int, None] = None,
|
74
84
|
):
|
75
85
|
src_h, src_w, _ = img.shape
|
76
86
|
if sum([src_h, src_w]) < 64:
|
@@ -79,7 +89,7 @@ class DetResizeForTest:
|
|
79
89
|
if self.resize_type == 0:
|
80
90
|
# img, shape = self.resize_image_type0(img)
|
81
91
|
img, [ratio_h, ratio_w] = self.resize_image_type0(
|
82
|
-
img, limit_side_len, limit_type
|
92
|
+
img, limit_side_len, limit_type, max_side_limit
|
83
93
|
)
|
84
94
|
elif self.resize_type == 2:
|
85
95
|
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
@@ -105,6 +115,8 @@ class DetResizeForTest:
|
|
105
115
|
resize_w = ori_w * resize_h / ori_h
|
106
116
|
N = math.ceil(resize_w / 32)
|
107
117
|
resize_w = N * 32
|
118
|
+
if resize_h == ori_h and resize_w == ori_w:
|
119
|
+
return img, [1.0, 1.0]
|
108
120
|
ratio_h = float(resize_h) / ori_h
|
109
121
|
ratio_w = float(resize_w) / ori_w
|
110
122
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
@@ -112,7 +124,11 @@ class DetResizeForTest:
|
|
112
124
|
return img, [ratio_h, ratio_w]
|
113
125
|
|
114
126
|
def resize_image_type0(
|
115
|
-
self,
|
127
|
+
self,
|
128
|
+
img,
|
129
|
+
limit_side_len: Union[int, None],
|
130
|
+
limit_type: Union[str, None],
|
131
|
+
max_side_limit: Union[int, None] = None,
|
116
132
|
):
|
117
133
|
"""
|
118
134
|
resize image to a size multiple of 32 which is required by the network
|
@@ -149,16 +165,28 @@ class DetResizeForTest:
|
|
149
165
|
resize_h = int(h * ratio)
|
150
166
|
resize_w = int(w * ratio)
|
151
167
|
|
168
|
+
if max(resize_h, resize_w) > max_side_limit:
|
169
|
+
logging.warning(
|
170
|
+
f"Resized image size ({resize_h}x{resize_w}) exceeds max_side_limit of {max_side_limit}. "
|
171
|
+
f"Resizing to fit within limit."
|
172
|
+
)
|
173
|
+
ratio = float(max_side_limit) / max(resize_h, resize_w)
|
174
|
+
resize_h, resize_w = int(resize_h * ratio), int(resize_w * ratio)
|
175
|
+
|
152
176
|
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
153
177
|
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
154
178
|
|
179
|
+
if resize_h == h and resize_w == w:
|
180
|
+
return img, [1.0, 1.0]
|
181
|
+
|
155
182
|
try:
|
156
183
|
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
157
184
|
return None, (None, None)
|
158
185
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
159
186
|
except:
|
160
187
|
logging.info(img.shape, resize_w, resize_h)
|
161
|
-
|
188
|
+
raise
|
189
|
+
|
162
190
|
ratio_h = resize_h / float(h)
|
163
191
|
ratio_w = resize_w / float(w)
|
164
192
|
return img, [ratio_h, ratio_w]
|
@@ -181,6 +209,10 @@ class DetResizeForTest:
|
|
181
209
|
max_stride = 128
|
182
210
|
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
183
211
|
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
212
|
+
|
213
|
+
if resize_h == h and resize_w == w:
|
214
|
+
return img, [1.0, 1.0]
|
215
|
+
|
184
216
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
185
217
|
ratio_h = resize_h / float(h)
|
186
218
|
ratio_w = resize_w / float(w)
|
@@ -191,6 +223,8 @@ class DetResizeForTest:
|
|
191
223
|
"""resize the image"""
|
192
224
|
resize_c, resize_h, resize_w = self.input_shape # (c, h, w)
|
193
225
|
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
226
|
+
if resize_h == ori_h and resize_w == ori_w:
|
227
|
+
return img, [1.0, 1.0]
|
194
228
|
ratio_h = float(resize_h) / ori_h
|
195
229
|
ratio_w = float(resize_w) / ori_w
|
196
230
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
@@ -200,7 +234,7 @@ class DetResizeForTest:
|
|
200
234
|
@benchmark.timeit
|
201
235
|
@class_requires_deps("opencv-contrib-python")
|
202
236
|
class NormalizeImage:
|
203
|
-
"""normalize image such as
|
237
|
+
"""normalize image such as subtract mean, divide std"""
|
204
238
|
|
205
239
|
def __init__(self, scale=None, mean=None, std=None, order="chw"):
|
206
240
|
super().__init__()
|
@@ -253,7 +287,7 @@ class DBPostProcess:
|
|
253
287
|
use_dilation=False,
|
254
288
|
score_mode="fast",
|
255
289
|
box_type="quad",
|
256
|
-
**kwargs
|
290
|
+
**kwargs,
|
257
291
|
):
|
258
292
|
super().__init__()
|
259
293
|
self.thresh = thresh
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import copy
|
16
|
-
from pathlib import Path
|
17
16
|
|
18
17
|
import numpy as np
|
19
18
|
|
@@ -27,15 +26,6 @@ if is_dep_available("opencv-contrib-python"):
|
|
27
26
|
@class_requires_deps("opencv-contrib-python")
|
28
27
|
class TextDetResult(BaseCVResult):
|
29
28
|
|
30
|
-
def _get_input_fn(self):
|
31
|
-
fn = super()._get_input_fn()
|
32
|
-
if (page_idx := self["page_index"]) is not None:
|
33
|
-
fp = Path(fn)
|
34
|
-
stem, suffix = fp.stem, fp.suffix
|
35
|
-
return f"{stem}_{page_idx}{suffix}"
|
36
|
-
else:
|
37
|
-
return fn
|
38
|
-
|
39
29
|
def _to_img(self):
|
40
30
|
"""draw rectangle"""
|
41
31
|
boxes = self["dt_polys"]
|
@@ -35,7 +35,7 @@ class TextRecResult(BaseCVResult):
|
|
35
35
|
|
36
36
|
def _to_img(self):
|
37
37
|
"""Draw label on image"""
|
38
|
-
image = Image.fromarray(self["input_img"])
|
38
|
+
image = Image.fromarray(self["input_img"][:, :, ::-1])
|
39
39
|
rec_text = self["rec_text"]
|
40
40
|
rec_score = self["rec_score"]
|
41
41
|
image = image.convert("RGB")
|
@@ -126,7 +126,8 @@ def create_pipeline(
|
|
126
126
|
pp_option (Optional[PaddlePredictorOption], optional): The options for
|
127
127
|
the PaddlePredictor. Defaults to None.
|
128
128
|
use_hpip (Optional[bool], optional): Whether to use the high-performance
|
129
|
-
inference plugin (HPIP).
|
129
|
+
inference plugin (HPIP). If set to None, the setting from the
|
130
|
+
configuration file or `config` will be used. Defaults to None.
|
130
131
|
hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional): The
|
131
132
|
high-performance inference configuration dictionary.
|
132
133
|
Defaults to None.
|
@@ -150,13 +151,16 @@ def create_pipeline(
|
|
150
151
|
pipeline,
|
151
152
|
config["pipeline_name"],
|
152
153
|
)
|
154
|
+
config = config.copy()
|
153
155
|
pipeline_name = config["pipeline_name"]
|
154
|
-
if device is None:
|
155
|
-
device = config.get("device", None)
|
156
156
|
if use_hpip is None:
|
157
|
-
use_hpip = config.
|
157
|
+
use_hpip = config.pop("use_hpip", False)
|
158
|
+
else:
|
159
|
+
config.pop("use_hpip", None)
|
158
160
|
if hpi_config is None:
|
159
|
-
hpi_config = config.
|
161
|
+
hpi_config = config.pop("hpi_config", None)
|
162
|
+
else:
|
163
|
+
config.pop("hpi_config", None)
|
160
164
|
|
161
165
|
pipeline = BasePipeline.get(pipeline_name)(
|
162
166
|
config=config,
|
@@ -0,0 +1,172 @@
|
|
1
|
+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import abc
|
16
|
+
from concurrent.futures import ThreadPoolExecutor
|
17
|
+
|
18
|
+
from ...utils import device as device_utils
|
19
|
+
from ..common.batch_sampler import ImageBatchSampler
|
20
|
+
from .base import BasePipeline
|
21
|
+
|
22
|
+
|
23
|
+
class MultiDeviceSimpleInferenceExecutor(object):
|
24
|
+
def __init__(self, pipelines, batch_sampler, *, postprocess_result=None):
|
25
|
+
super().__init__()
|
26
|
+
self._pipelines = pipelines
|
27
|
+
self._batch_sampler = batch_sampler
|
28
|
+
self._postprocess_result = postprocess_result
|
29
|
+
|
30
|
+
@property
|
31
|
+
def pipelines(self):
|
32
|
+
return self._pipelines
|
33
|
+
|
34
|
+
def execute(
|
35
|
+
self,
|
36
|
+
input,
|
37
|
+
*args,
|
38
|
+
**kwargs,
|
39
|
+
):
|
40
|
+
with ThreadPoolExecutor(max_workers=len(self._pipelines)) as pool:
|
41
|
+
input_batches = self._batch_sampler(input)
|
42
|
+
out_of_data = False
|
43
|
+
while not out_of_data:
|
44
|
+
input_future_pairs = []
|
45
|
+
for pipeline in self._pipelines:
|
46
|
+
try:
|
47
|
+
input_batch = next(input_batches)
|
48
|
+
except StopIteration:
|
49
|
+
out_of_data = True
|
50
|
+
break
|
51
|
+
input_instances = input_batch.instances
|
52
|
+
future = pool.submit(
|
53
|
+
lambda pipeline, input_instances, args, kwargs: list(
|
54
|
+
pipeline.predict(input_instances, *args, **kwargs)
|
55
|
+
),
|
56
|
+
pipeline,
|
57
|
+
input_instances,
|
58
|
+
args,
|
59
|
+
kwargs,
|
60
|
+
)
|
61
|
+
input_future_pairs.append((input_batch, future))
|
62
|
+
|
63
|
+
# We synchronize here to keep things simple (no data
|
64
|
+
# prefetching, no queues, no dedicated workers), although
|
65
|
+
# it's less efficient.
|
66
|
+
for input_batch, future in input_future_pairs:
|
67
|
+
result = future.result()
|
68
|
+
for input_path, result_item in zip(input_batch.input_paths, result):
|
69
|
+
result_item["input_path"] = input_path
|
70
|
+
if self._postprocess_result:
|
71
|
+
result = self._postprocess_result(result, input_batch)
|
72
|
+
yield from result
|
73
|
+
|
74
|
+
|
75
|
+
class AutoParallelSimpleInferencePipeline(BasePipeline):
    """Pipeline wrapper that runs on one device or fans out over several.

    When the configured device resolves to more than one device ID, one
    internal pipeline is created per ID and inputs are dispatched through a
    ``MultiDeviceSimpleInferenceExecutor``. Otherwise a single internal
    pipeline is created and used directly.

    Subclasses provide ``_create_internal_pipeline``, ``_get_batch_size`` and
    ``_create_batch_sampler``; ``_postprocess_result`` is an optional hook.
    """

    def __init__(
        self,
        config,
        *args,
        **kwargs,
    ):
        """Build the internal pipeline(s) for the configured device(s).

        Args:
            config: Pipeline configuration; handed unchanged to
                ``_create_internal_pipeline`` and ``_get_batch_size``.
            *args: Positional arguments forwarded to ``BasePipeline.__init__``.
            **kwargs: Keyword arguments forwarded to ``BasePipeline.__init__``.
        """
        super().__init__(*args, **kwargs)

        self._multi_device_inference = False
        if self.device is not None:
            device_type, device_ids = device_utils.parse_device(self.device)
            if device_ids is not None and len(device_ids) > 1:
                self._multi_device_inference = True
                # One internal pipeline per device ID, each bound to exactly
                # that single device.
                self._pipelines = []
                for device_id in device_ids:
                    pipeline = self._create_internal_pipeline(
                        config, device_utils.constr_device(device_type, [device_id])
                    )
                    self._pipelines.append(pipeline)
                batch_size = self._get_batch_size(config)
                batch_sampler = self._create_batch_sampler(batch_size)
                self._executor = MultiDeviceSimpleInferenceExecutor(
                    self._pipelines,
                    batch_sampler,
                    postprocess_result=self._postprocess_result,
                )
        if not self._multi_device_inference:
            self._pipeline = self._create_internal_pipeline(config, self.device)

    @property
    def multi_device_inference(self):
        """bool: Whether inference is spread over several internal pipelines."""
        return self._multi_device_inference

    def __getattr__(self, name):
        """Delegate unknown attributes to the (first) internal pipeline.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If ``name`` is part of this wrapper's own state
                and has not been set yet, or if the internal pipeline does
                not have the attribute either.
        """
        # The delegation below reads these attributes from `self`. If one of
        # them is missing (e.g. the instance is only partly constructed, or
        # is being probed by copy/pickle), looking it up would re-enter
        # `__getattr__` and recurse until `RecursionError`. Fail fast with
        # the conventional `AttributeError` instead.
        if name in ("_multi_device_inference", "_executor", "_pipeline"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        if self._multi_device_inference:
            first_pipeline = self._executor.pipelines[0]
            return getattr(first_pipeline, name)
        else:
            return getattr(self._pipeline, name)

    def predict(
        self,
        input,
        *args,
        **kwargs,
    ):
        """Run prediction and yield results one by one.

        Args:
            input: Input handed to the executor (multi-device) or to the
                internal pipeline's ``predict`` (single device).
            *args: Extra positional arguments forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Yields:
            Whatever the executor or the internal pipeline yields.
        """
        if self._multi_device_inference:
            yield from self._executor.execute(
                input,
                *args,
                **kwargs,
            )
        else:
            yield from self._pipeline.predict(
                input,
                *args,
                **kwargs,
            )

    @abc.abstractmethod
    def _create_internal_pipeline(self, config, device):
        """Create one internal pipeline for ``config`` bound to ``device``."""
        raise NotImplementedError

    @abc.abstractmethod
    def _get_batch_size(self, config):
        """Return the batch size, taken from ``config``, for the batch sampler."""
        raise NotImplementedError

    @abc.abstractmethod
    def _create_batch_sampler(self, batch_size):
        """Create the batch sampler passed to the multi-device executor."""
        raise NotImplementedError

    def _postprocess_result(self, result, input_batch):
        """Hook applied to each batch result in multi-device mode.

        The default implementation returns ``result`` unchanged.
        """
        return result
|
149
|
+
|
150
|
+
|
151
|
+
class AutoParallelImageSimpleInferencePipeline(AutoParallelSimpleInferencePipeline):
    """Auto-parallel pipeline base that samples inputs with ``ImageBatchSampler``.

    Concrete subclasses supply ``_pipeline_cls`` (the pipeline class that does
    the actual work) and ``_get_batch_size``.
    """

    @property
    @abc.abstractmethod
    def _pipeline_cls(self):
        """The pipeline class instantiated by ``_create_internal_pipeline``."""
        raise NotImplementedError

    def _create_internal_pipeline(self, config, device):
        """Instantiate ``_pipeline_cls`` for ``device`` with this object's options."""
        pipeline_cls = self._pipeline_cls
        init_kwargs = {
            "device": device,
            "pp_option": self.pp_option,
            "use_hpip": self.use_hpip,
            "hpi_config": self.hpi_config,
        }
        return pipeline_cls(config, **init_kwargs)

    def _create_batch_sampler(self, batch_size):
        """Return an ``ImageBatchSampler`` built with ``batch_size``."""
        sampler = ImageBatchSampler(batch_size)
        return sampler

    def _postprocess_result(self, result, input_batch):
        """Write each batch entry's page index into the matching result item.

        Items are updated in place under the ``"page_index"`` key, and the
        same ``result`` object is returned.
        """
        page_indexes = input_batch.page_indexes
        for index, entry in zip(page_indexes, result):
            entry["page_index"] = index
        return result
|
@@ -20,15 +20,13 @@ from ....utils.deps import pipeline_requires_extra
|
|
20
20
|
from ...models.anomaly_detection.result import UadResult
|
21
21
|
from ...utils.hpi import HPIConfig
|
22
22
|
from ...utils.pp_option import PaddlePredictorOption
|
23
|
+
from .._parallel import AutoParallelImageSimpleInferencePipeline
|
23
24
|
from ..base import BasePipeline
|
24
25
|
|
25
26
|
|
26
|
-
|
27
|
-
class AnomalyDetectionPipeline(BasePipeline):
|
27
|
+
class _AnomalyDetectionPipeline(BasePipeline):
|
28
28
|
"""Image AnomalyDetectionPipeline Pipeline"""
|
29
29
|
|
30
|
-
entities = "anomaly_detection"
|
31
|
-
|
32
30
|
def __init__(
|
33
31
|
self,
|
34
32
|
config: Dict,
|
@@ -44,9 +42,9 @@ class AnomalyDetectionPipeline(BasePipeline):
|
|
44
42
|
device (str, optional): Device to run the predictions on. Defaults to None.
|
45
43
|
pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
|
46
44
|
use_hpip (bool, optional): Whether to use the high-performance
|
47
|
-
inference plugin (HPIP). Defaults to False.
|
45
|
+
inference plugin (HPIP) by default. Defaults to False.
|
48
46
|
hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
|
49
|
-
The high-performance inference configuration dictionary.
|
47
|
+
The default high-performance inference configuration dictionary.
|
50
48
|
Defaults to None.
|
51
49
|
"""
|
52
50
|
|
@@ -70,3 +68,15 @@ class AnomalyDetectionPipeline(BasePipeline):
|
|
70
68
|
UadResult: The predicted anomaly results.
|
71
69
|
"""
|
72
70
|
yield from self.anomaly_detetion_model(input)
|
71
|
+
|
72
|
+
|
73
|
+
@pipeline_requires_extra("cv")
class AnomalyDetectionPipeline(AutoParallelImageSimpleInferencePipeline):
    """Auto-parallel wrapper around ``_AnomalyDetectionPipeline``."""

    entities = "anomaly_detection"

    @property
    def _pipeline_cls(self):
        """The pipeline class created for each device."""
        return _AnomalyDetectionPipeline

    def _get_batch_size(self, config):
        """Return the AnomalyDetection sub-module's batch size (1 if unset)."""
        module_config = config["SubModules"]["AnomalyDetection"]
        return module_config.get("batch_size", 1)
|
@@ -21,12 +21,13 @@ from ...common.batch_sampler import ImageBatchSampler
|
|
21
21
|
from ...common.reader import ReadImage
|
22
22
|
from ...utils.hpi import HPIConfig
|
23
23
|
from ...utils.pp_option import PaddlePredictorOption
|
24
|
+
from .._parallel import AutoParallelImageSimpleInferencePipeline
|
24
25
|
from ..base import BasePipeline
|
25
26
|
from ..components import CropByBoxes
|
26
27
|
from .result import AttributeRecResult
|
27
28
|
|
28
29
|
|
29
|
-
class AttributeRecPipeline(BasePipeline):
|
30
|
+
class _AttributeRecPipeline(BasePipeline):
|
30
31
|
"""Attribute Rec Pipeline"""
|
31
32
|
|
32
33
|
def __init__(
|
@@ -100,6 +101,15 @@ class AttributeRecPipeline(BasePipeline):
|
|
100
101
|
return AttributeRecResult(single_img_res)
|
101
102
|
|
102
103
|
|
104
|
+
class AttributeRecPipeline(AutoParallelImageSimpleInferencePipeline):
    """Auto-parallel wrapper around ``_AttributeRecPipeline``."""

    @property
    def _pipeline_cls(self):
        """The pipeline class created for each device."""
        return _AttributeRecPipeline

    def _get_batch_size(self, config):
        """Return the Detection sub-module's batch size.

        Falls back to 1 when ``batch_size`` is omitted from the config, so a
        missing key does not raise ``KeyError``.
        """
        detection_config = config["SubModules"]["Detection"]
        return detection_config.get("batch_size", 1)
|
111
|
+
|
112
|
+
|
103
113
|
@pipeline_requires_extra("cv")
|
104
114
|
class PedestrianAttributeRecPipeline(AttributeRecPipeline):
|
105
115
|
entities = "pedestrian_attribute_recognition"
|
@@ -48,9 +48,9 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
48
48
|
device (str, optional): The device to use for prediction. Defaults to None.
|
49
49
|
pp_option (PaddlePredictorOption, optional): The options for PaddlePredictor. Defaults to None.
|
50
50
|
use_hpip (bool, optional): Whether to use the high-performance
|
51
|
-
inference plugin (HPIP). Defaults to False.
|
51
|
+
inference plugin (HPIP) by default. Defaults to False.
|
52
52
|
hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
|
53
|
-
The high-performance inference configuration dictionary.
|
53
|
+
The default high-performance inference configuration dictionary.
|
54
54
|
Defaults to None.
|
55
55
|
"""
|
56
56
|
super().__init__()
|
@@ -96,12 +96,20 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
96
96
|
|
97
97
|
logging.info("Creating model: %s", (config["model_name"], model_dir))
|
98
98
|
|
99
|
+
# TODO(gaotingquan): support to specify pp_option by model in pipeline
|
100
|
+
if self.pp_option is not None:
|
101
|
+
pp_option = self.pp_option.copy()
|
102
|
+
pp_option.model_name = config["model_name"]
|
103
|
+
pp_option.run_mode = self.pp_option.run_mode
|
104
|
+
else:
|
105
|
+
pp_option = None
|
106
|
+
|
99
107
|
model = create_predictor(
|
100
108
|
model_name=config["model_name"],
|
101
109
|
model_dir=model_dir,
|
102
110
|
device=self.device,
|
103
111
|
batch_size=config.get("batch_size", 1),
|
104
|
-
pp_option=
|
112
|
+
pp_option=pp_option,
|
105
113
|
use_hpip=use_hpip,
|
106
114
|
hpi_config=hpi_config,
|
107
115
|
**kwargs,
|
@@ -132,7 +140,9 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
132
140
|
pipeline = create_pipeline(
|
133
141
|
config=config,
|
134
142
|
device=self.device,
|
135
|
-
pp_option=
|
143
|
+
pp_option=(
|
144
|
+
self.pp_option.copy() if self.pp_option is not None else self.pp_option
|
145
|
+
),
|
136
146
|
use_hpip=use_hpip,
|
137
147
|
hpi_config=hpi_config,
|
138
148
|
)
|
@@ -178,7 +178,7 @@ class FaissBuilder:
|
|
178
178
|
|
179
179
|
@classmethod
|
180
180
|
def _get_index_type(cls, metric_type, index_type, num=None):
|
181
|
-
# if IVF method, cal ivf number
|
181
|
+
# if IVF method, cal ivf number automatically
|
182
182
|
if index_type == "IVF":
|
183
183
|
index_type = index_type + str(min(int(num // 8), 65536))
|
184
184
|
if metric_type in cls.BINARY_METRIC_TYPE:
|