paddlex 3.0.1__py3-none-any.whl → 3.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -1
- paddlex/inference/models/common/static_infer.py +18 -14
- paddlex/inference/models/common/ts/funcs.py +19 -8
- paddlex/inference/models/formula_recognition/predictor.py +1 -1
- paddlex/inference/models/formula_recognition/processors.py +2 -2
- paddlex/inference/models/text_recognition/result.py +1 -1
- paddlex/inference/pipelines/layout_parsing/layout_objects.py +859 -0
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +144 -205
- paddlex/inference/pipelines/layout_parsing/result_v2.py +6 -270
- paddlex/inference/pipelines/layout_parsing/setting.py +1 -0
- paddlex/inference/pipelines/layout_parsing/utils.py +108 -312
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +302 -247
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +156 -104
- paddlex/inference/pipelines/ocr/result.py +2 -2
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +1 -1
- paddlex/inference/serving/basic_serving/_app.py +46 -13
- paddlex/inference/utils/hpi.py +23 -16
- paddlex/inference/utils/hpi_model_info_collection.json +627 -202
- paddlex/inference/utils/misc.py +20 -0
- paddlex/inference/utils/mkldnn_blocklist.py +36 -2
- paddlex/inference/utils/official_models.py +126 -5
- paddlex/inference/utils/pp_option.py +48 -4
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +12 -2
- paddlex/ops/__init__.py +6 -3
- paddlex/utils/deps.py +2 -2
- paddlex/utils/device.py +4 -19
- paddlex/utils/flags.py +9 -0
- paddlex/utils/subclass_register.py +2 -2
- {paddlex-3.0.1.dist-info → paddlex-3.0.2.dist-info}/METADATA +307 -162
- {paddlex-3.0.1.dist-info → paddlex-3.0.2.dist-info}/RECORD +34 -32
- {paddlex-3.0.1.dist-info → paddlex-3.0.2.dist-info}/WHEEL +1 -1
- {paddlex-3.0.1.dist-info → paddlex-3.0.2.dist-info}/entry_points.txt +1 -0
- {paddlex-3.0.1.dist-info/licenses → paddlex-3.0.2.dist-info}/LICENSE +0 -0
- {paddlex-3.0.1.dist-info → paddlex-3.0.2.dist-info}/top_level.txt +0 -0
paddlex/inference/utils/misc.py
ADDED
@@ -0,0 +1,20 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def is_mkldnn_available():
+    # XXX: Not sure if this is the best way to check if MKL-DNN is available
+    from paddle.inference import Config
+
+    return hasattr(Config, "set_mkldnn_cache_capacity")
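The new helper is deliberately feature-detection based: MKL-DNN support is inferred from the presence of `set_mkldnn_cache_capacity` on `paddle.inference.Config` rather than from a version check. A minimal sketch of how a caller might use it to guard MKL-DNN-specific setup; the `make_config` wrapper below is hypothetical, only `is_mkldnn_available` itself comes from this release:

```python
from paddle.inference import Config

from paddlex.inference.utils.misc import is_mkldnn_available


def make_config(model_file: str, params_file: str) -> Config:
    # Build a Paddle Inference config and enable MKL-DNN only when the
    # installed paddlepaddle build actually exposes it.
    config = Config(model_file, params_file)
    if is_mkldnn_available():
        config.enable_mkldnn()                 # safe: capability detected above
        config.set_mkldnn_cache_capacity(10)   # matches the new PaddleX default
    return config
```

Probing the API surface keeps the check valid across PaddlePaddle builds that do or do not ship MKL-DNN.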
paddlex/inference/utils/mkldnn_blocklist.py
CHANGED
@@ -13,12 +13,46 @@
 # limitations under the License.
 
 MKLDNN_BLOCKLIST = [
-    "SLANeXt_wired",
-    "SLANeXt_wireless",
     "LaTeX_OCR_rec",
     "PP-FormulaNet-L",
     "PP-FormulaNet-S",
     "UniMERNet",
+    "UVDoc",
+    "Cascade-MaskRCNN-ResNet50-FPN",
+    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
+    "Mask-RT-DETR-M",
+    "Mask-RT-DETR-S",
+    "MaskRCNN-ResNeXt101-vd-FPN",
+    "MaskRCNN-ResNet101-FPN",
+    "MaskRCNN-ResNet101-vd-FPN",
+    "MaskRCNN-ResNet50-FPN",
+    "MaskRCNN-ResNet50-vd-FPN",
+    "MaskRCNN-ResNet50",
+    "SOLOv2",
+    "PP-TinyPose_128x96",
+    "PP-TinyPose_256x192",
+    "Cascade-FasterRCNN-ResNet50-FPN",
+    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "Co-DINO-Swin-L",
+    "Co-Deformable-DETR-Swin-T",
+    "FasterRCNN-ResNeXt101-vd-FPN",
+    "FasterRCNN-ResNet101-FPN",
+    "FasterRCNN-ResNet101",
+    "FasterRCNN-ResNet34-FPN",
+    "FasterRCNN-ResNet50-FPN",
+    "FasterRCNN-ResNet50-vd-FPN",
+    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "FasterRCNN-ResNet50",
+    "FasterRCNN-Swin-Tiny-FPN",
+    "MaskFormer_small",
+    "MaskFormer_tiny",
+    "SLANeXt_wired",
+    "SLANeXt_wireless",
+    "SLANet",
+    "SLANet_plus",
+    "YOWO",
+    "SAM-H_box",
+    "SAM-H_point",
     "PP-FormulaNet_plus-L",
     "PP-FormulaNet_plus-M",
     "PP-FormulaNet_plus-S",
paddlex/inference/utils/official_models.py
CHANGED
@@ -12,11 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import shutil
+import tempfile
+from functools import lru_cache
 from pathlib import Path
 
+import huggingface_hub as hf_hub
+
+hf_hub.logging.set_verbosity_error()
+
+import requests
+
 from ...utils import logging
 from ...utils.cache import CACHE_DIR
 from ...utils.download import download_and_extract
+from ...utils.flags import MODEL_SOURCE
 
 OFFICIAL_MODELS = {
     "ResNet18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/ResNet18_infer.tar",
@@ -352,17 +363,127 @@ PP-OCRv5_mobile_rec_infer.tar",
 }
 
 
+HUGGINGFACE_MODELS = [
+    "arabic_PP-OCRv3_mobile_rec",
+    "chinese_cht_PP-OCRv3_mobile_rec",
+    "ch_RepSVTR_rec",
+    "ch_SVTRv2_rec",
+    "cyrillic_PP-OCRv3_mobile_rec",
+    "devanagari_PP-OCRv3_mobile_rec",
+    "en_PP-OCRv3_mobile_rec",
+    "en_PP-OCRv4_mobile_rec",
+    "japan_PP-OCRv3_mobile_rec",
+    "ka_PP-OCRv3_mobile_rec",
+    "korean_PP-OCRv3_mobile_rec",
+    "LaTeX_OCR_rec",
+    "latin_PP-OCRv3_mobile_rec",
+    "PicoDet_layout_1x",
+    "PicoDet_layout_1x_table",
+    "PicoDet-L_layout_17cls",
+    "PicoDet-L_layout_3cls",
+    "PicoDet-S_layout_17cls",
+    "PicoDet-S_layout_3cls",
+    "PP-DocBee2-3B",
+    "PP-DocBee-2B",
+    "PP-DocBee-7B",
+    "PP-DocBlockLayout",
+    "PP-DocLayout-L",
+    "PP-DocLayout-M",
+    "PP-DocLayout_plus-L",
+    "PP-DocLayout-S",
+    "PP-FormulaNet-L",
+    "PP-FormulaNet_plus-L",
+    "PP-FormulaNet_plus-M",
+    "PP-FormulaNet_plus-S",
+    "PP-FormulaNet-S",
+    "PP-LCNet_x1_0_doc_ori",
+    "PP-LCNet_x1_0_table_cls",
+    "PP-OCRv3_mobile_det",
+    "PP-OCRv3_mobile_rec",
+    "PP-OCRv3_server_det",
+    "PP-OCRv4_mobile_det",
+    "PP-OCRv4_mobile_rec",
+    "PP-OCRv4_mobile_seal_det",
+    "PP-OCRv4_server_det",
+    "PP-OCRv4_server_rec_doc",
+    "PP-OCRv4_server_rec",
+    "PP-OCRv4_server_seal_det",
+    "PP-OCRv5_mobile_det",
+    "PP-OCRv5_mobile_rec",
+    "PP-OCRv5_server_det",
+    "PP-OCRv5_server_rec",
+    "RT-DETR-H_layout_17cls",
+    "RT-DETR-H_layout_3cls",
+    "RT-DETR-L_wired_table_cell_det",
+    "RT-DETR-L_wireless_table_cell_det",
+    "SLANet",
+    "SLANet_plus",
+    "SLANeXt_wired",
+    "SLANeXt_wireless",
+    "ta_PP-OCRv3_mobile_rec",
+    "te_PP-OCRv3_mobile_rec",
+    "UniMERNet",
+    "UVDoc",
+]
+
+
+@lru_cache(1)
+def is_huggingface_accessible():
+    try:
+        response = requests.get("https://huggingface.co", timeout=1)
+        return response.ok == True
+    except requests.exceptions.RequestException as e:
+        return False
+
+
 class OfficialModelsDict(dict):
     """Official Models Dict"""
 
+    _save_dir = Path(CACHE_DIR) / "official_models"
+
     def __getitem__(self, key):
-
-
+        def _download_from_bos():
+            url = super(OfficialModelsDict, self).__getitem__(key)
+            download_and_extract(url, self._save_dir, f"{key}", overwrite=False)
+            return self._save_dir / f"{key}"
+
+        def _download_from_hf():
+            local_dir = self._save_dir / f"{key}"
+            try:
+                if os.path.exists(local_dir):
+                    hf_hub.snapshot_download(
+                        repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
+                    )
+                else:
+                    with tempfile.TemporaryDirectory() as td:
+                        temp_dir = os.path.join(td, "temp_dir")
+                        hf_hub.snapshot_download(
+                            repo_id=f"PaddlePaddle/{key}", local_dir=temp_dir
+                        )
+                        shutil.move(temp_dir, local_dir)
+            except Exception as e:
+                logging.warning(
+                    f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."
+                )
+                return _download_from_bos()
+            return local_dir
+
         logging.info(
-            f"Using official model ({key}), the model files will be automatically downloaded and saved in {
+            f"Using official model ({key}), the model files will be automatically downloaded and saved in {self._save_dir}."
         )
-
-
+
+        if (
+            MODEL_SOURCE.lower() == "huggingface"
+            and is_huggingface_accessible()
+            and key in HUGGINGFACE_MODELS
+        ):
+            return _download_from_hf()
+        elif MODEL_SOURCE.lower() == "modelscope":
+            raise Exception(
+                f"ModelScope is not supported! Please use `HuggingFace` or `BOS`."
+            )
+        else:
+            return _download_from_bos()
 
 
 official_models = OfficialModelsDict(OFFICIAL_MODELS)
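The lookup logic above prefers Hugging Face when `MODEL_SOURCE` is `huggingface` (the default), the hub is reachable, and the model appears in `HUGGINGFACE_MODELS`; `modelscope` is rejected, and everything else falls back to BOS. A minimal sketch of pinning the source and triggering a download; the model key is one entry from the list above, and printing the returned path is only illustrative:

```python
import os

# Force downloads from BOS instead of Hugging Face; the flag is read once
# when paddlex.utils.flags is imported, so set it before importing paddlex.
os.environ["PADDLE_PDX_MODEL_SOURCE"] = "bos"

from paddlex.inference.utils.official_models import official_models

# Indexing the dict downloads the model if needed and returns the local
# directory (CACHE_DIR / "official_models" / <model name>).
model_dir = official_models["PP-OCRv5_mobile_det"]
print(model_dir)
```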
paddlex/inference/utils/pp_option.py
CHANGED
@@ -23,13 +23,34 @@ from ...utils.device import (
     parse_device,
     set_env_for_device_type,
 )
-from ...utils.flags import
+from ...utils.flags import (
+    DISABLE_MKLDNN_MODEL_BL,
+    DISABLE_TRT_MODEL_BL,
+    ENABLE_MKLDNN_BYDEFAULT,
+    USE_PIR_TRT,
+)
+from .misc import is_mkldnn_available
 from .mkldnn_blocklist import MKLDNN_BLOCKLIST
 from .new_ir_blocklist import NEWIR_BLOCKLIST
 from .trt_blocklist import TRT_BLOCKLIST
 from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
 
 
+def get_default_run_mode(model_name, device_type):
+    if not model_name:
+        return "paddle"
+    if device_type != "cpu":
+        return "paddle"
+    if (
+        ENABLE_MKLDNN_BYDEFAULT
+        and is_mkldnn_available()
+        and model_name not in MKLDNN_BLOCKLIST
+    ):
+        return "mkldnn"
+    else:
+        return "paddle"
+
+
 class PaddlePredictorOption(object):
     """Paddle Inference Engine Option"""
 
@@ -104,7 +125,7 @@ class PaddlePredictorOption(object):
         device_type, device_ids = parse_device(get_default_device())
 
         default_config = {
-            "run_mode":
+            "run_mode": get_default_run_mode(self.model_name, device_type),
             "device_type": device_type,
             "device_id": None if device_ids is None else device_ids[0],
             "cpu_threads": 8,
@@ -119,6 +140,7 @@ class PaddlePredictorOption(object):
             "trt_dynamic_shape_input_data": None,  # only for trt
             "trt_shape_range_info_path": None,  # only for trt
             "trt_allow_rebuild_at_runtime": True,  # only for trt
+            "mkldnn_cache_capacity": 10,
         }
         return default_config
 
@@ -139,15 +161,29 @@ class PaddlePredictorOption(object):
                 f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
             )
 
+        if run_mode.startswith("mkldnn") and not is_mkldnn_available():
+            logging.warning("MKL-DNN is not available. Using `paddle` instead.")
+            run_mode = "paddle"
+
+        # TODO: Check if trt is available
+
         if self._model_name is not None:
             # TRT Blocklist
-            if
+            if (
+                not DISABLE_TRT_MODEL_BL
+                and run_mode.startswith("trt")
+                and self._model_name in TRT_BLOCKLIST
+            ):
                 logging.warning(
                     f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
                 )
                 run_mode = "paddle"
             # MKLDNN Blocklist
-            elif
+            elif (
+                not DISABLE_MKLDNN_MODEL_BL
+                and run_mode.startswith("mkldnn")
+                and self._model_name in MKLDNN_BLOCKLIST
+            ):
                 logging.warning(
                     f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
                 )
@@ -294,6 +330,14 @@ class PaddlePredictorOption(object):
     def trt_allow_rebuild_at_runtime(self, trt_allow_rebuild_at_runtime):
         self._update("trt_allow_rebuild_at_runtime", trt_allow_rebuild_at_runtime)
 
+    @property
+    def mkldnn_cache_capacity(self):
+        return self._cfg["mkldnn_cache_capacity"]
+
+    @mkldnn_cache_capacity.setter
+    def mkldnn_cache_capacity(self, capacity: int):
+        self._update("mkldnn_cache_capacity", capacity)
+
     # For backward compatibility
     # TODO: Issue deprecation warnings
     @property
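Taken together, these hunks make `mkldnn` the default run mode for CPU inference on models outside the blocklist and expose `mkldnn_cache_capacity` (default 10) as a first-class option. A minimal sketch of setting both explicitly; passing a `PaddlePredictorOption` via `pp_option` mirrors the public `create_pipeline` API, but the no-argument constructor, pipeline name, and values here are assumptions for illustration:

```python
from paddlex import create_pipeline
from paddlex.inference.utils.pp_option import PaddlePredictorOption

option = PaddlePredictorOption()
option.run_mode = "mkldnn"          # validated by the setter; falls back to "paddle"
                                    # if MKL-DNN is unavailable or the model is blocklisted
option.mkldnn_cache_capacity = 16   # new knob in this release (default 10)

pipeline = create_pipeline(pipeline="OCR", device="cpu", pp_option=option)
```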
paddlex/modules/semantic_segmentation/dataset_checker/__init__.py
CHANGED
@@ -38,8 +38,18 @@ class SegDatasetChecker(BaseDatasetChecker):
             str: the root directory of dataset.
         """
         anno_dirs = list(Path(dataset_dir).glob("**/images"))
-
-
+        if len(anno_dirs) == 1:
+            dataset_dir = anno_dirs[0].parent.as_posix()
+        elif len(anno_dirs) == 0:
+            dataset_dir = Path(dataset_dir)
+        else:
+            raise ValueError(
+                f"Segmentation Dataset Format Error: We currently only support `PaddleX` and `Labelme` formats. "
+                f"For `PaddleX` format, your dataset root must contain exactly one `images` directory. "
+                f"For `Labelme` format, your dataset root must contain no `images` directories. "
+                f"However, your dataset root contains {len(anno_dirs)} `images` directories. "
+                f"Please adjust your dataset structure to comply with the supported formats."
+            )
         return dataset_dir
 
     def convert_dataset(self, src_dataset_dir: str) -> str:
paddlex/ops/__init__.py
CHANGED
@@ -66,11 +66,14 @@ class CustomOpNotFoundException(Exception):
 
 
 class CustomOperatorPathFinder:
-    def
+    def find_spec(self, fullname: str, path, target=None):
         if not fullname.startswith("paddlex.ops"):
             return None
-
-
+        return importlib.machinery.ModuleSpec(
+            name=fullname,
+            loader=CustomOperatorPathLoader(),
+            is_package=False,
+        )
 
 
 class CustomOperatorPathLoader:
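`find_spec` now returns a complete `importlib.machinery.ModuleSpec` wired to `CustomOperatorPathLoader`. For readers unfamiliar with the import-hook machinery involved, here is a self-contained sketch of the same meta-path-finder pattern using only the standard library; names like `DemoFinder` are invented for illustration and are not part of PaddleX:

```python
import importlib.abc
import importlib.machinery
import sys
import types


class DemoLoader(importlib.abc.Loader):
    def create_module(self, spec):
        # Returning None lets the import system create a default module object.
        return None

    def exec_module(self, module: types.ModuleType):
        # Populate the module lazily; a real loader would build/compile the op here.
        module.answer = 42


class DemoFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        if not fullname.startswith("demo_ops"):
            return None
        return importlib.machinery.ModuleSpec(
            name=fullname, loader=DemoLoader(), is_package=False
        )


sys.meta_path.insert(0, DemoFinder())

import demo_ops  # resolved by DemoFinder, not by the filesystem

print(demo_ops.answer)  # 42
```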
paddlex/utils/deps.py
CHANGED
@@ -73,7 +73,7 @@ def _get_dep_specs():
 DEP_SPECS = _get_dep_specs()
 
 
-def
+def get_dep_version(dep):
     try:
         return importlib.metadata.version(dep)
     except importlib.metadata.PackageNotFoundError:
@@ -101,7 +101,7 @@ def is_dep_available(dep, /, check_version=None):
         check_version = True
     else:
         check_version = False
-    version =
+    version = get_dep_version(dep)
     if version is None:
         return False
     if check_version:
paddlex/utils/device.py
CHANGED
@@ -15,8 +15,6 @@
 import os
 from contextlib import ContextDecorator
 
-import GPUtil
-
 from . import logging
 from .custom_device_list import (
     DCU_WHITELIST,
@@ -41,25 +39,12 @@ def constr_device(device_type, device_ids):
 
 
 def get_default_device():
-
-
-
-
-            "Failed to query GPU devices. Falling back to CPU.", exc_info=True
-        )
-        has_gpus = False
+    import paddle
+
+    if paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0:
+        return constr_device("gpu", [0])
     else:
-        has_gpus = bool(gpu_list)
-    if not has_gpus:
-        # HACK
-        if os.path.exists("/etc/nv_tegra_release"):
-            logging.debug(
-                "The current device appears to be an NVIDIA Jetson. GPU 0 will be used as the default device."
-            )
-    if not has_gpus:
         return "cpu"
-    else:
-        return constr_device("gpu", [0])
 
 
 def parse_device(device):
paddlex/utils/flags.py
CHANGED
@@ -51,7 +51,16 @@ FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", True)
 USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", True)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
+DISABLE_TRT_MODEL_BL = get_flag_from_env_var("PADDLE_PDX_DISABLE_TRT_MODEL_BL", False)
+DISABLE_MKLDNN_MODEL_BL = get_flag_from_env_var(
+    "PADDLE_PDX_DISABLE_MKLDNN_MODEL_BL", False
+)
 LOCAL_FONT_FILE_PATH = get_flag_from_env_var("PADDLE_PDX_LOCAL_FONT_FILE_PATH", None)
+ENABLE_MKLDNN_BYDEFAULT = get_flag_from_env_var(
+    "PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT", True
+)
+
+MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface")
 
 
 # Inference Benchmark
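These flags are read once at import time via `get_flag_from_env_var` (or `os.environ.get` for `MODEL_SOURCE`). A sketch of toggling them from Python before importing PaddleX; shell `export` works equally well, and the `"0"`/`"1"` spellings assume `get_flag_from_env_var` accepts them, so treat the exact values as an assumption:

```python
import os

# Opt out of the MKL-DNN-by-default behaviour introduced in this release,
# and ignore the MKL-DNN model blocklist when a run mode is forced.
os.environ["PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT"] = "0"
os.environ["PADDLE_PDX_DISABLE_MKLDNN_MODEL_BL"] = "1"

# Choose where official models are fetched from ("huggingface" is the default).
os.environ["PADDLE_PDX_MODEL_SOURCE"] = "bos"

import paddlex  # the flags above are now baked into paddlex.utils.flags
```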
paddlex/utils/subclass_register.py
CHANGED
@@ -46,7 +46,7 @@ class AutoRegisterMetaClass(type):
         if bases:
             for base in bases:
                 base_cls = mcs.__find_base_class(base)
-                if base_cls:
+                if base_cls and hasattr(cls, mcs.__model_type_attr_name):
                     mcs.__register_to_base_class(base_cls, cls)
 
     @classmethod
@@ -64,7 +64,7 @@ class AutoRegisterMetaClass(type):
 
     @classmethod
     def __register_to_base_class(mcs, base, cls):
-        cls_entity_name = getattr(cls, mcs.__model_type_attr_name
+        cls_entity_name = getattr(cls, mcs.__model_type_attr_name)
         if isinstance(cls_entity_name, str):
             cls_entity_name = [cls_entity_name]
 
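The added `hasattr` guard means a subclass that never declares the registration attribute is now skipped instead of failing inside `__register_to_base_class`. A small self-contained sketch of the same auto-registration pattern (this is not the real `AutoRegisterMetaClass`, just an illustration of why the guard matters):

```python
class AutoRegisterMeta(type):
    """Registers subclasses that declare an `entities` attribute."""

    registry = {}

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        # Guard mirrors the fix above: only register classes that actually
        # declare the attribute, instead of failing with AttributeError.
        if bases and hasattr(cls, "entities"):
            names = cls.entities if isinstance(cls.entities, (list, tuple)) else [cls.entities]
            for n in names:
                AutoRegisterMeta.registry[n] = cls


class Base(metaclass=AutoRegisterMeta):
    pass


class TextDet(Base):
    entities = "text_detection"


class Helper(Base):  # declares no `entities`: silently skipped, no error raised
    pass


print(AutoRegisterMeta.registry)  # {'text_detection': <class '__main__.TextDet'>}
```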