paddlex 3.0.1__py3-none-any.whl → 3.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. paddlex/.version +1 -1
  2. paddlex/inference/models/base/predictor/base_predictor.py +2 -0
  3. paddlex/inference/models/common/static_infer.py +20 -14
  4. paddlex/inference/models/common/ts/funcs.py +19 -8
  5. paddlex/inference/models/formula_recognition/predictor.py +1 -1
  6. paddlex/inference/models/formula_recognition/processors.py +2 -2
  7. paddlex/inference/models/text_recognition/result.py +1 -1
  8. paddlex/inference/pipelines/layout_parsing/layout_objects.py +859 -0
  9. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +144 -205
  10. paddlex/inference/pipelines/layout_parsing/result_v2.py +13 -272
  11. paddlex/inference/pipelines/layout_parsing/setting.py +1 -0
  12. paddlex/inference/pipelines/layout_parsing/utils.py +108 -312
  13. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +302 -247
  14. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +156 -104
  15. paddlex/inference/pipelines/ocr/result.py +2 -2
  16. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +1 -1
  17. paddlex/inference/serving/basic_serving/_app.py +47 -13
  18. paddlex/inference/serving/infra/utils.py +22 -17
  19. paddlex/inference/utils/hpi.py +60 -25
  20. paddlex/inference/utils/hpi_model_info_collection.json +627 -204
  21. paddlex/inference/utils/misc.py +20 -0
  22. paddlex/inference/utils/mkldnn_blocklist.py +36 -2
  23. paddlex/inference/utils/official_models.py +126 -5
  24. paddlex/inference/utils/pp_option.py +81 -21
  25. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +12 -2
  26. paddlex/ops/__init__.py +6 -3
  27. paddlex/utils/deps.py +2 -2
  28. paddlex/utils/device.py +4 -19
  29. paddlex/utils/download.py +10 -7
  30. paddlex/utils/flags.py +9 -0
  31. paddlex/utils/subclass_register.py +2 -2
  32. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/METADATA +307 -162
  33. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/RECORD +37 -35
  34. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/WHEEL +1 -1
  35. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/entry_points.txt +1 -0
  36. {paddlex-3.0.1.dist-info/licenses → paddlex-3.0.3.dist-info}/LICENSE +0 -0
  37. {paddlex-3.0.1.dist-info → paddlex-3.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ def is_mkldnn_available():
17
+ # XXX: Not sure if this is the best way to check if MKL-DNN is available
18
+ from paddle.inference import Config
19
+
20
+ return hasattr(Config, "set_mkldnn_cache_capacity")
@@ -13,12 +13,46 @@
13
13
  # limitations under the License.
14
14
 
15
15
  MKLDNN_BLOCKLIST = [
16
- "SLANeXt_wired",
17
- "SLANeXt_wireless",
18
16
  "LaTeX_OCR_rec",
19
17
  "PP-FormulaNet-L",
20
18
  "PP-FormulaNet-S",
21
19
  "UniMERNet",
20
+ "UVDoc",
21
+ "Cascade-MaskRCNN-ResNet50-FPN",
22
+ "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
23
+ "Mask-RT-DETR-M",
24
+ "Mask-RT-DETR-S",
25
+ "MaskRCNN-ResNeXt101-vd-FPN",
26
+ "MaskRCNN-ResNet101-FPN",
27
+ "MaskRCNN-ResNet101-vd-FPN",
28
+ "MaskRCNN-ResNet50-FPN",
29
+ "MaskRCNN-ResNet50-vd-FPN",
30
+ "MaskRCNN-ResNet50",
31
+ "SOLOv2",
32
+ "PP-TinyPose_128x96",
33
+ "PP-TinyPose_256x192",
34
+ "Cascade-FasterRCNN-ResNet50-FPN",
35
+ "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
36
+ "Co-DINO-Swin-L",
37
+ "Co-Deformable-DETR-Swin-T",
38
+ "FasterRCNN-ResNeXt101-vd-FPN",
39
+ "FasterRCNN-ResNet101-FPN",
40
+ "FasterRCNN-ResNet101",
41
+ "FasterRCNN-ResNet34-FPN",
42
+ "FasterRCNN-ResNet50-FPN",
43
+ "FasterRCNN-ResNet50-vd-FPN",
44
+ "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
45
+ "FasterRCNN-ResNet50",
46
+ "FasterRCNN-Swin-Tiny-FPN",
47
+ "MaskFormer_small",
48
+ "MaskFormer_tiny",
49
+ "SLANeXt_wired",
50
+ "SLANeXt_wireless",
51
+ "SLANet",
52
+ "SLANet_plus",
53
+ "YOWO",
54
+ "SAM-H_box",
55
+ "SAM-H_point",
22
56
  "PP-FormulaNet_plus-L",
23
57
  "PP-FormulaNet_plus-M",
24
58
  "PP-FormulaNet_plus-S",
@@ -12,11 +12,22 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import os
16
+ import shutil
17
+ import tempfile
18
+ from functools import lru_cache
15
19
  from pathlib import Path
16
20
 
21
+ import huggingface_hub as hf_hub
22
+
23
+ hf_hub.logging.set_verbosity_error()
24
+
25
+ import requests
26
+
17
27
  from ...utils import logging
18
28
  from ...utils.cache import CACHE_DIR
19
29
  from ...utils.download import download_and_extract
30
+ from ...utils.flags import MODEL_SOURCE
20
31
 
21
32
  OFFICIAL_MODELS = {
22
33
  "ResNet18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/ResNet18_infer.tar",
@@ -352,17 +363,127 @@ PP-OCRv5_mobile_rec_infer.tar",
352
363
  }
353
364
 
354
365
 
366
+ HUGGINGFACE_MODELS = [
367
+ "arabic_PP-OCRv3_mobile_rec",
368
+ "chinese_cht_PP-OCRv3_mobile_rec",
369
+ "ch_RepSVTR_rec",
370
+ "ch_SVTRv2_rec",
371
+ "cyrillic_PP-OCRv3_mobile_rec",
372
+ "devanagari_PP-OCRv3_mobile_rec",
373
+ "en_PP-OCRv3_mobile_rec",
374
+ "en_PP-OCRv4_mobile_rec",
375
+ "japan_PP-OCRv3_mobile_rec",
376
+ "ka_PP-OCRv3_mobile_rec",
377
+ "korean_PP-OCRv3_mobile_rec",
378
+ "LaTeX_OCR_rec",
379
+ "latin_PP-OCRv3_mobile_rec",
380
+ "PicoDet_layout_1x",
381
+ "PicoDet_layout_1x_table",
382
+ "PicoDet-L_layout_17cls",
383
+ "PicoDet-L_layout_3cls",
384
+ "PicoDet-S_layout_17cls",
385
+ "PicoDet-S_layout_3cls",
386
+ "PP-DocBee2-3B",
387
+ "PP-DocBee-2B",
388
+ "PP-DocBee-7B",
389
+ "PP-DocBlockLayout",
390
+ "PP-DocLayout-L",
391
+ "PP-DocLayout-M",
392
+ "PP-DocLayout_plus-L",
393
+ "PP-DocLayout-S",
394
+ "PP-FormulaNet-L",
395
+ "PP-FormulaNet_plus-L",
396
+ "PP-FormulaNet_plus-M",
397
+ "PP-FormulaNet_plus-S",
398
+ "PP-FormulaNet-S",
399
+ "PP-LCNet_x1_0_doc_ori",
400
+ "PP-LCNet_x1_0_table_cls",
401
+ "PP-OCRv3_mobile_det",
402
+ "PP-OCRv3_mobile_rec",
403
+ "PP-OCRv3_server_det",
404
+ "PP-OCRv4_mobile_det",
405
+ "PP-OCRv4_mobile_rec",
406
+ "PP-OCRv4_mobile_seal_det",
407
+ "PP-OCRv4_server_det",
408
+ "PP-OCRv4_server_rec_doc",
409
+ "PP-OCRv4_server_rec",
410
+ "PP-OCRv4_server_seal_det",
411
+ "PP-OCRv5_mobile_det",
412
+ "PP-OCRv5_mobile_rec",
413
+ "PP-OCRv5_server_det",
414
+ "PP-OCRv5_server_rec",
415
+ "RT-DETR-H_layout_17cls",
416
+ "RT-DETR-H_layout_3cls",
417
+ "RT-DETR-L_wired_table_cell_det",
418
+ "RT-DETR-L_wireless_table_cell_det",
419
+ "SLANet",
420
+ "SLANet_plus",
421
+ "SLANeXt_wired",
422
+ "SLANeXt_wireless",
423
+ "ta_PP-OCRv3_mobile_rec",
424
+ "te_PP-OCRv3_mobile_rec",
425
+ "UniMERNet",
426
+ "UVDoc",
427
+ ]
428
+
429
+
430
+ @lru_cache(1)
431
+ def is_huggingface_accessible():
432
+ try:
433
+ response = requests.get("https://huggingface.co", timeout=1)
434
+ return response.ok == True
435
+ except requests.exceptions.RequestException as e:
436
+ return False
437
+
438
+
355
439
  class OfficialModelsDict(dict):
356
440
  """Official Models Dict"""
357
441
 
442
+ _save_dir = Path(CACHE_DIR) / "official_models"
443
+
358
444
  def __getitem__(self, key):
359
- url = super().__getitem__(key)
360
- save_dir = Path(CACHE_DIR) / "official_models"
445
+ def _download_from_bos():
446
+ url = super(OfficialModelsDict, self).__getitem__(key)
447
+ download_and_extract(url, self._save_dir, f"{key}", overwrite=False)
448
+ return self._save_dir / f"{key}"
449
+
450
+ def _download_from_hf():
451
+ local_dir = self._save_dir / f"{key}"
452
+ try:
453
+ if os.path.exists(local_dir):
454
+ hf_hub.snapshot_download(
455
+ repo_id=f"PaddlePaddle/{key}", local_dir=local_dir
456
+ )
457
+ else:
458
+ with tempfile.TemporaryDirectory() as td:
459
+ temp_dir = os.path.join(td, "temp_dir")
460
+ hf_hub.snapshot_download(
461
+ repo_id=f"PaddlePaddle/{key}", local_dir=temp_dir
462
+ )
463
+ shutil.move(temp_dir, local_dir)
464
+ except Exception as e:
465
+ logging.warning(
466
+ f"Encounter exception when download model from huggingface: \n{e}.\nPaddleX would try to download from BOS."
467
+ )
468
+ return _download_from_bos()
469
+ return local_dir
470
+
361
471
  logging.info(
362
- f"Using official model ({key}), the model files will be automatically downloaded and saved in {save_dir}."
472
+ f"Using official model ({key}), the model files will be automatically downloaded and saved in {self._save_dir}."
363
473
  )
364
- download_and_extract(url, save_dir, f"{key}", overwrite=False)
365
- return save_dir / f"{key}"
474
+
475
+ if (
476
+ MODEL_SOURCE.lower() == "huggingface"
477
+ and is_huggingface_accessible()
478
+ and key in HUGGINGFACE_MODELS
479
+ ):
480
+ return _download_from_hf()
481
+ elif MODEL_SOURCE.lower() == "modelscope":
482
+ raise Exception(
483
+ f"ModelScope is not supported! Please use `HuggingFace` or `BOS`."
484
+ )
485
+ else:
486
+ return _download_from_bos()
366
487
 
367
488
 
368
489
  official_models = OfficialModelsDict(OFFICIAL_MODELS)
@@ -23,13 +23,34 @@ from ...utils.device import (
23
23
  parse_device,
24
24
  set_env_for_device_type,
25
25
  )
26
- from ...utils.flags import USE_PIR_TRT
26
+ from ...utils.flags import (
27
+ DISABLE_MKLDNN_MODEL_BL,
28
+ DISABLE_TRT_MODEL_BL,
29
+ ENABLE_MKLDNN_BYDEFAULT,
30
+ USE_PIR_TRT,
31
+ )
32
+ from .misc import is_mkldnn_available
27
33
  from .mkldnn_blocklist import MKLDNN_BLOCKLIST
28
34
  from .new_ir_blocklist import NEWIR_BLOCKLIST
29
35
  from .trt_blocklist import TRT_BLOCKLIST
30
36
  from .trt_config import TRT_CFG_SETTING, TRT_PRECISION_MAP
31
37
 
32
38
 
39
+ def get_default_run_mode(model_name, device_type):
40
+ if not model_name:
41
+ return "paddle"
42
+ if device_type != "cpu":
43
+ return "paddle"
44
+ if (
45
+ ENABLE_MKLDNN_BYDEFAULT
46
+ and is_mkldnn_available()
47
+ and model_name not in MKLDNN_BLOCKLIST
48
+ ):
49
+ return "mkldnn"
50
+ else:
51
+ return "paddle"
52
+
53
+
33
54
  class PaddlePredictorOption(object):
34
55
  """Paddle Inference Engine Option"""
35
56
 
@@ -48,6 +69,7 @@ class PaddlePredictorOption(object):
48
69
 
49
70
  def __init__(self, model_name=None, **kwargs):
50
71
  super().__init__()
72
+ self._is_default_run_mode = True
51
73
  self._model_name = model_name
52
74
  self._cfg = {}
53
75
  self._init_option(**kwargs)
@@ -85,6 +107,10 @@ class PaddlePredictorOption(object):
85
107
  raise Exception(
86
108
  f"{k} is not supported to set! The supported option is: {self._get_settable_attributes()}"
87
109
  )
110
+
111
+ if "run_mode" in self._cfg:
112
+ self._is_default_run_mode = False
113
+
88
114
  for k, v in self._get_default_config().items():
89
115
  self._cfg.setdefault(k, v)
90
116
 
@@ -101,12 +127,16 @@ class PaddlePredictorOption(object):
101
127
 
102
128
  def _get_default_config(self):
103
129
  """get default config"""
104
- device_type, device_ids = parse_device(get_default_device())
130
+ if self.device_type is None:
131
+ device_type, device_ids = parse_device(get_default_device())
132
+ device_id = None if device_ids is None else device_ids[0]
133
+ else:
134
+ device_type, device_id = self.device_type, self.device_id
105
135
 
106
136
  default_config = {
107
- "run_mode": "paddle",
137
+ "run_mode": get_default_run_mode(self.model_name, device_type),
108
138
  "device_type": device_type,
109
- "device_id": None if device_ids is None else device_ids[0],
139
+ "device_id": device_id,
110
140
  "cpu_threads": 8,
111
141
  "delete_pass": [],
112
142
  "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
@@ -119,6 +149,7 @@ class PaddlePredictorOption(object):
119
149
  "trt_dynamic_shape_input_data": None, # only for trt
120
150
  "trt_shape_range_info_path": None, # only for trt
121
151
  "trt_allow_rebuild_at_runtime": True, # only for trt
152
+ "mkldnn_cache_capacity": 10,
122
153
  }
123
154
  return default_config
124
155
 
@@ -126,9 +157,15 @@ class PaddlePredictorOption(object):
126
157
  self._cfg[k] = v
127
158
  self.changed = True
128
159
 
160
+ def reset_run_mode_by_default(self, model_name=None, device_type=None):
161
+ if self._is_default_run_mode:
162
+ model_name = model_name or self.model_name
163
+ device_type = device_type or self.device_type
164
+ self._update("run_mode", get_default_run_mode(model_name, device_type))
165
+
129
166
  @property
130
167
  def run_mode(self):
131
- return self._cfg["run_mode"]
168
+ return self._cfg.get("run_mode")
132
169
 
133
170
  @run_mode.setter
134
171
  def run_mode(self, run_mode: str):
@@ -139,25 +176,40 @@ class PaddlePredictorOption(object):
139
176
  f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
140
177
  )
141
178
 
179
+ if run_mode.startswith("mkldnn") and not is_mkldnn_available():
180
+ logging.warning("MKL-DNN is not available. Using `paddle` instead.")
181
+ run_mode = "paddle"
182
+
183
+ # TODO: Check if trt is available
184
+
142
185
  if self._model_name is not None:
143
186
  # TRT Blocklist
144
- if run_mode.startswith("trt") and self._model_name in TRT_BLOCKLIST:
187
+ if (
188
+ not DISABLE_TRT_MODEL_BL
189
+ and run_mode.startswith("trt")
190
+ and self._model_name in TRT_BLOCKLIST
191
+ ):
145
192
  logging.warning(
146
193
  f"The model({self._model_name}) is not supported to run in trt mode! Using `paddle` instead!"
147
194
  )
148
195
  run_mode = "paddle"
149
196
  # MKLDNN Blocklist
150
- elif run_mode.startswith("mkldnn") and self._model_name in MKLDNN_BLOCKLIST:
197
+ elif (
198
+ not DISABLE_MKLDNN_MODEL_BL
199
+ and run_mode.startswith("mkldnn")
200
+ and self._model_name in MKLDNN_BLOCKLIST
201
+ ):
151
202
  logging.warning(
152
203
  f"The model({self._model_name}) is not supported to run in MKLDNN mode! Using `paddle` instead!"
153
204
  )
154
205
  run_mode = "paddle"
155
206
 
207
+ self._is_default_run_mode = False
156
208
  self._update("run_mode", run_mode)
157
209
 
158
210
  @property
159
211
  def device_type(self):
160
- return self._cfg["device_type"]
212
+ return self._cfg.get("device_type")
161
213
 
162
214
  @device_type.setter
163
215
  def device_type(self, device_type):
@@ -175,7 +227,7 @@ class PaddlePredictorOption(object):
175
227
 
176
228
  @property
177
229
  def device_id(self):
178
- return self._cfg["device_id"]
230
+ return self._cfg.get("device_id")
179
231
 
180
232
  @device_id.setter
181
233
  def device_id(self, device_id):
@@ -183,7 +235,7 @@ class PaddlePredictorOption(object):
183
235
 
184
236
  @property
185
237
  def cpu_threads(self):
186
- return self._cfg["cpu_threads"]
238
+ return self._cfg.get("cpu_threads")
187
239
 
188
240
  @cpu_threads.setter
189
241
  def cpu_threads(self, cpu_threads):
@@ -194,7 +246,7 @@ class PaddlePredictorOption(object):
194
246
 
195
247
  @property
196
248
  def delete_pass(self):
197
- return self._cfg["delete_pass"]
249
+ return self._cfg.get("delete_pass")
198
250
 
199
251
  @delete_pass.setter
200
252
  def delete_pass(self, delete_pass):
@@ -202,7 +254,7 @@ class PaddlePredictorOption(object):
202
254
 
203
255
  @property
204
256
  def enable_new_ir(self):
205
- return self._cfg["enable_new_ir"]
257
+ return self._cfg.get("enable_new_ir")
206
258
 
207
259
  @enable_new_ir.setter
208
260
  def enable_new_ir(self, enable_new_ir: bool):
@@ -211,7 +263,7 @@ class PaddlePredictorOption(object):
211
263
 
212
264
  @property
213
265
  def enable_cinn(self):
214
- return self._cfg["enable_cinn"]
266
+ return self._cfg.get("enable_cinn")
215
267
 
216
268
  @enable_cinn.setter
217
269
  def enable_cinn(self, enable_cinn: bool):
@@ -220,7 +272,7 @@ class PaddlePredictorOption(object):
220
272
 
221
273
  @property
222
274
  def trt_cfg_setting(self):
223
- return self._cfg["trt_cfg_setting"]
275
+ return self._cfg.get("trt_cfg_setting")
224
276
 
225
277
  @trt_cfg_setting.setter
226
278
  def trt_cfg_setting(self, config: Dict):
@@ -232,7 +284,7 @@ class PaddlePredictorOption(object):
232
284
 
233
285
  @property
234
286
  def trt_use_dynamic_shapes(self):
235
- return self._cfg["trt_use_dynamic_shapes"]
287
+ return self._cfg.get("trt_use_dynamic_shapes")
236
288
 
237
289
  @trt_use_dynamic_shapes.setter
238
290
  def trt_use_dynamic_shapes(self, trt_use_dynamic_shapes):
@@ -240,7 +292,7 @@ class PaddlePredictorOption(object):
240
292
 
241
293
  @property
242
294
  def trt_collect_shape_range_info(self):
243
- return self._cfg["trt_collect_shape_range_info"]
295
+ return self._cfg.get("trt_collect_shape_range_info")
244
296
 
245
297
  @trt_collect_shape_range_info.setter
246
298
  def trt_collect_shape_range_info(self, trt_collect_shape_range_info):
@@ -248,7 +300,7 @@ class PaddlePredictorOption(object):
248
300
 
249
301
  @property
250
302
  def trt_discard_cached_shape_range_info(self):
251
- return self._cfg["trt_discard_cached_shape_range_info"]
303
+ return self._cfg.get("trt_discard_cached_shape_range_info")
252
304
 
253
305
  @trt_discard_cached_shape_range_info.setter
254
306
  def trt_discard_cached_shape_range_info(self, trt_discard_cached_shape_range_info):
@@ -258,7 +310,7 @@ class PaddlePredictorOption(object):
258
310
 
259
311
  @property
260
312
  def trt_dynamic_shapes(self):
261
- return self._cfg["trt_dynamic_shapes"]
313
+ return self._cfg.get("trt_dynamic_shapes")
262
314
 
263
315
  @trt_dynamic_shapes.setter
264
316
  def trt_dynamic_shapes(self, trt_dynamic_shapes: Dict[str, List[List[int]]]):
@@ -269,7 +321,7 @@ class PaddlePredictorOption(object):
269
321
 
270
322
  @property
271
323
  def trt_dynamic_shape_input_data(self):
272
- return self._cfg["trt_dynamic_shape_input_data"]
324
+ return self._cfg.get("trt_dynamic_shape_input_data")
273
325
 
274
326
  @trt_dynamic_shape_input_data.setter
275
327
  def trt_dynamic_shape_input_data(
@@ -279,7 +331,7 @@ class PaddlePredictorOption(object):
279
331
 
280
332
  @property
281
333
  def trt_shape_range_info_path(self):
282
- return self._cfg["trt_shape_range_info_path"]
334
+ return self._cfg.get("trt_shape_range_info_path")
283
335
 
284
336
  @trt_shape_range_info_path.setter
285
337
  def trt_shape_range_info_path(self, trt_shape_range_info_path: str):
@@ -288,12 +340,20 @@ class PaddlePredictorOption(object):
288
340
 
289
341
  @property
290
342
  def trt_allow_rebuild_at_runtime(self):
291
- return self._cfg["trt_allow_rebuild_at_runtime"]
343
+ return self._cfg.get("trt_allow_rebuild_at_runtime")
292
344
 
293
345
  @trt_allow_rebuild_at_runtime.setter
294
346
  def trt_allow_rebuild_at_runtime(self, trt_allow_rebuild_at_runtime):
295
347
  self._update("trt_allow_rebuild_at_runtime", trt_allow_rebuild_at_runtime)
296
348
 
349
+ @property
350
+ def mkldnn_cache_capacity(self):
351
+ return self._cfg.get("mkldnn_cache_capacity")
352
+
353
+ @mkldnn_cache_capacity.setter
354
+ def mkldnn_cache_capacity(self, capacity: int):
355
+ self._update("mkldnn_cache_capacity", capacity)
356
+
297
357
  # For backward compatibility
298
358
  # TODO: Issue deprecation warnings
299
359
  @property
@@ -38,8 +38,18 @@ class SegDatasetChecker(BaseDatasetChecker):
38
38
  str: the root directory of dataset.
39
39
  """
40
40
  anno_dirs = list(Path(dataset_dir).glob("**/images"))
41
- assert len(anno_dirs) == 1
42
- dataset_dir = anno_dirs[0].parent.as_posix()
41
+ if len(anno_dirs) == 1:
42
+ dataset_dir = anno_dirs[0].parent.as_posix()
43
+ elif len(anno_dirs) == 0:
44
+ dataset_dir = Path(dataset_dir)
45
+ else:
46
+ raise ValueError(
47
+ f"Segmentation Dataset Format Error: We currently only support `PaddleX` and `Labelme` formats. "
48
+ f"For `PaddleX` format, your dataset root must contain exactly one `images` directory. "
49
+ f"For `Labelme` format, your dataset root must contain no `images` directories. "
50
+ f"However, your dataset root contains {len(anno_dirs)} `images` directories. "
51
+ f"Please adjust your dataset structure to comply with the supported formats."
52
+ )
43
53
  return dataset_dir
44
54
 
45
55
  def convert_dataset(self, src_dataset_dir: str) -> str:
paddlex/ops/__init__.py CHANGED
@@ -66,11 +66,14 @@ class CustomOpNotFoundException(Exception):
66
66
 
67
67
 
68
68
  class CustomOperatorPathFinder:
69
- def find_module(self, fullname: str, path: str = None):
69
+ def find_spec(self, fullname: str, path, target=None):
70
70
  if not fullname.startswith("paddlex.ops"):
71
71
  return None
72
-
73
- return CustomOperatorPathLoader()
72
+ return importlib.machinery.ModuleSpec(
73
+ name=fullname,
74
+ loader=CustomOperatorPathLoader(),
75
+ is_package=False,
76
+ )
74
77
 
75
78
 
76
79
  class CustomOperatorPathLoader:
paddlex/utils/deps.py CHANGED
@@ -73,7 +73,7 @@ def _get_dep_specs():
73
73
  DEP_SPECS = _get_dep_specs()
74
74
 
75
75
 
76
- def _get_dep_version(dep):
76
+ def get_dep_version(dep):
77
77
  try:
78
78
  return importlib.metadata.version(dep)
79
79
  except importlib.metadata.PackageNotFoundError:
@@ -101,7 +101,7 @@ def is_dep_available(dep, /, check_version=None):
101
101
  check_version = True
102
102
  else:
103
103
  check_version = False
104
- version = _get_dep_version(dep)
104
+ version = get_dep_version(dep)
105
105
  if version is None:
106
106
  return False
107
107
  if check_version:
paddlex/utils/device.py CHANGED
@@ -15,8 +15,6 @@
15
15
  import os
16
16
  from contextlib import ContextDecorator
17
17
 
18
- import GPUtil
19
-
20
18
  from . import logging
21
19
  from .custom_device_list import (
22
20
  DCU_WHITELIST,
@@ -41,25 +39,12 @@ def constr_device(device_type, device_ids):
41
39
 
42
40
 
43
41
  def get_default_device():
44
- try:
45
- gpu_list = GPUtil.getGPUs()
46
- except Exception:
47
- logging.debug(
48
- "Failed to query GPU devices. Falling back to CPU.", exc_info=True
49
- )
50
- has_gpus = False
42
+ import paddle
43
+
44
+ if paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0:
45
+ return constr_device("gpu", [0])
51
46
  else:
52
- has_gpus = bool(gpu_list)
53
- if not has_gpus:
54
- # HACK
55
- if os.path.exists("/etc/nv_tegra_release"):
56
- logging.debug(
57
- "The current device appears to be an NVIDIA Jetson. GPU 0 will be used as the default device."
58
- )
59
- if not has_gpus:
60
47
  return "cpu"
61
- else:
62
- return constr_device("gpu", [0])
63
48
 
64
49
 
65
50
  def parse_device(device):
paddlex/utils/download.py CHANGED
@@ -39,14 +39,14 @@ class _ProgressPrinter(object):
39
39
  str_ += "\n"
40
40
  self._last_time = 0
41
41
  if time.time() - self._last_time >= self._flush_intvl:
42
- sys.stdout.write(f"\r{str_}")
42
+ sys.stderr.write(f"\r{str_}")
43
43
  self._last_time = time.time()
44
- sys.stdout.flush()
44
+ sys.stderr.flush()
45
45
 
46
46
 
47
47
  def _download(url, save_path, print_progress):
48
48
  if print_progress:
49
- print(f"Connecting to {url} ...")
49
+ print(f"Connecting to {url} ...", file=sys.stderr)
50
50
 
51
51
  with requests.get(url, stream=True, timeout=15) as r:
52
52
  r.raise_for_status()
@@ -62,7 +62,10 @@ def _download(url, save_path, print_progress):
62
62
  total_length = int(total_length)
63
63
  if print_progress:
64
64
  printer = _ProgressPrinter()
65
- print(f"Downloading {os.path.basename(save_path)} ...")
65
+ print(
66
+ f"Downloading {os.path.basename(save_path)} ...",
67
+ file=sys.stderr,
68
+ )
66
69
  for data in r.iter_content(chunk_size=4096):
67
70
  dl += len(data)
68
71
  f.write(data)
@@ -95,17 +98,17 @@ def _extract_tar_file(file_path, extd_dir):
95
98
  try:
96
99
  f.extract(file, extd_dir)
97
100
  except KeyError:
98
- print(f"File {file} not found in the archive.")
101
+ print(f"File {file} not found in the archive.", file=sys.stderr)
99
102
  yield total_num, index
100
103
  except Exception as e:
101
- print(f"An error occurred: {e}")
104
+ print(f"An error occurred: {e}", file=sys.stderr)
102
105
 
103
106
 
104
107
  def _extract(file_path, extd_dir, print_progress):
105
108
  """extract"""
106
109
  if print_progress:
107
110
  printer = _ProgressPrinter()
108
- print(f"Extracting {os.path.basename(file_path)}")
111
+ print(f"Extracting {os.path.basename(file_path)}", file=sys.stderr)
109
112
 
110
113
  if zipfile.is_zipfile(file_path):
111
114
  handler = _extract_zip_file
paddlex/utils/flags.py CHANGED
@@ -51,7 +51,16 @@ FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", True)
51
51
  USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", True)
52
52
  DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
53
53
  DISABLE_CINN_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_CINN_MODEL_WL", False)
54
+ DISABLE_TRT_MODEL_BL = get_flag_from_env_var("PADDLE_PDX_DISABLE_TRT_MODEL_BL", False)
55
+ DISABLE_MKLDNN_MODEL_BL = get_flag_from_env_var(
56
+ "PADDLE_PDX_DISABLE_MKLDNN_MODEL_BL", False
57
+ )
54
58
  LOCAL_FONT_FILE_PATH = get_flag_from_env_var("PADDLE_PDX_LOCAL_FONT_FILE_PATH", None)
59
+ ENABLE_MKLDNN_BYDEFAULT = get_flag_from_env_var(
60
+ "PADDLE_PDX_ENABLE_MKLDNN_BYDEFAULT", True
61
+ )
62
+
63
+ MODEL_SOURCE = os.environ.get("PADDLE_PDX_MODEL_SOURCE", "huggingface")
55
64
 
56
65
 
57
66
  # Inference Benchmark
@@ -46,7 +46,7 @@ class AutoRegisterMetaClass(type):
46
46
  if bases:
47
47
  for base in bases:
48
48
  base_cls = mcs.__find_base_class(base)
49
- if base_cls:
49
+ if base_cls and hasattr(cls, mcs.__model_type_attr_name):
50
50
  mcs.__register_to_base_class(base_cls, cls)
51
51
 
52
52
  @classmethod
@@ -64,7 +64,7 @@ class AutoRegisterMetaClass(type):
64
64
 
65
65
  @classmethod
66
66
  def __register_to_base_class(mcs, base, cls):
67
- cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
67
+ cls_entity_name = getattr(cls, mcs.__model_type_attr_name)
68
68
  if isinstance(cls_entity_name, str):
69
69
  cls_entity_name = [cls_entity_name]
70
70