mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,7 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  from msprobe.pytorch.parse_tool.lib.interactive_cli import InteractiveCli
18
17
  from msprobe.pytorch.common.log import logger
19
18
 
@@ -22,7 +22,7 @@ from collections import namedtuple
22
22
  from msprobe.pytorch.parse_tool.lib.utils import Util
23
23
  from msprobe.pytorch.parse_tool.lib.config import Const
24
24
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
25
- from msprobe.core.common.file_utils import FileChecker, create_directory, load_npy, save_npy_to_txt, write_csv
25
+ from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
26
26
 
27
27
 
28
28
  class Compare:
@@ -49,10 +49,10 @@ class Compare:
49
49
  dump_file = self.util.path_strip(dump_file)
50
50
  file_name = ""
51
51
  if os.path.isfile(dump_file):
52
- self.log.info("Covert file is: %s", dump_file)
52
+ self.log.info("Covert file is: %s" % dump_file)
53
53
  file_name = os.path.basename(dump_file)
54
54
  elif os.path.isdir(dump_file):
55
- self.log.info("Convert all files in path: %s", dump_file)
55
+ self.log.info("Convert all files in path: %s" % dump_file)
56
56
  file_name = ""
57
57
  output = output if output else Const.DUMP_CONVERT_DIR
58
58
  convert = self.convert(dump_file, data_format, output, msaccucmp_path)
@@ -62,7 +62,7 @@ class Compare:
62
62
  summary_txt = ["SrcFile: %s" % dump_file]
63
63
  for convert_file in convert_files.values():
64
64
  summary_txt.append(" - %s" % convert_file.file_name)
65
- self.log.info("Transfer result is saved in : %s", os.path.realpath(output))
65
+ self.log.info("Transfer result is saved in : %s" % os.path.realpath(output))
66
66
  self.util.print_panel("\n".join(summary_txt))
67
67
 
68
68
  def convert(self, dump_file, data_format, output, msaccucmp_path):
@@ -114,11 +114,11 @@ class Compare:
114
114
  shape_left = data_left.shape
115
115
  shape_right = data_right.shape
116
116
  if shape_left != shape_right:
117
- self.log.warning("Data shape not equal: %s vs %s", data_left.shape, data_right.shape)
117
+ self.log.warning("Data shape not equal: %s vs %s" % (data_left.shape, data_right.shape))
118
118
  data_left = data_left.reshape(-1)
119
119
  data_right = data_right.reshape(-1)
120
120
  if data_left.shape[0] != data_right.shape[0]:
121
- self.log.warning("Data size not equal: %s vs %s", data_left.shape, data_right.shape)
121
+ self.log.warning("Data size not equal: %s vs %s" % (data_left.shape, data_right.shape))
122
122
  if data_left.shape[0] < data_right.shape[0]:
123
123
  data_left = np.pad(data_left, (0, data_right.shape[0] - data_left.shape[0]), 'constant')
124
124
  else:
@@ -160,7 +160,7 @@ class Compare:
160
160
  if shape != bench_shape or dtype != bench_dtype:
161
161
  self.log.error(
162
162
  "Shape or dtype between two npy files is inconsistent. Please check the two files."
163
- "File 1: %s, file 2: %s", file, bench_file)
163
+ "File 1: %s, file 2: %s" % (file, bench_file))
164
164
  self.util.deal_with_dir_or_file_inconsistency(output_path)
165
165
  return
166
166
  md5_consistency = False
@@ -236,25 +236,18 @@ class Compare:
236
236
  golden_subdir_path = os.path.join(golden_dump_dir, golden_subdir_name)
237
237
  self.compare_timestamp_directory(my_subdir_path, golden_subdir_path, output_path)
238
238
  self.util.change_filemode_safe(output_path)
239
- self.log.info("Compare result is saved in : %s", output_path)
239
+ self.log.info("Compare result is saved in : %s" % (output_path))
240
240
 
241
241
  def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
242
242
  dump_dir = self.util.path_strip(dump_dir)
243
- for root, _, files in os.walk(dump_dir, topdown=True):
244
- path_checker = FileChecker(root)
245
- path_checker.common_check()
246
- for file in files:
247
- file_path = os.path.join(root, file)
248
- file_name = os.path.basename(file_path)
249
- parts = file_name.split(".")
250
- if len(parts) < 5:
251
- continue
252
- op_name = parts[1]
253
- timestamp = parts[-1]
254
- output_path = os.path.join(output_dir, op_name, timestamp)
255
- self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path)
256
- path_depth = root.count(os.sep)
257
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
258
- yield root, _, files
259
- else:
260
- _[:] = []
243
+ files = os_walk_for_files(dump_dir, Const.MAX_TRAVERSAL_DEPTH)
244
+ filepaths = [os.path.join(file['root'], file['file']) for file in files]
245
+ for path in filepaths:
246
+ filename = os.path.basename(path)
247
+ parts = filename.split(".")
248
+ if len(parts) < 5:
249
+ continue
250
+ op_name = parts[1]
251
+ timestamp = parts[-1]
252
+ output_path = os.path.join(output_dir, op_name, timestamp)
253
+ self.convert_dump_to_npy(path, param, output_path, msaccucmp_path)
@@ -33,7 +33,7 @@ class Const:
33
33
  OFFLINE_DUMP_CONVERT_PATTERN = \
34
34
  r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
35
35
  r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
36
- NUMPY_PATTERN = r"^[\w\-_-]\.npy$"
36
+ NUMPY_PATTERN = r"^[\w\-_.]+\.npy$"
37
37
  NPY_SUFFIX = ".npy"
38
38
  PKL_SUFFIX = ".pkl"
39
39
  DIRECTORY_LENGTH = 4096
@@ -110,6 +110,9 @@ class ParseTool:
110
110
  parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set rtol')
111
111
  parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set atol')
112
112
  args = parser.parse_args(argv)
113
+ self.util.check_positive(args.count)
114
+ self.util.check_positive(args.rtol)
115
+ self.util.check_positive(args.atol)
113
116
  self.util.check_path_valid(args.my_dump_path)
114
117
  self.util.check_path_valid(args.golden_dump_path)
115
118
  self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
@@ -129,8 +132,7 @@ class ParseTool:
129
132
  " '-m' and '-g'.")
130
133
  raise ParseException("My directory path and golden directory path is same.")
131
134
  output_path = self.util.path_strip(args.output_path) if args.output_path else Const.BATCH_COMPARE_DIR
132
- if not os.path.isdir(output_path):
133
- os.makedirs(output_path, mode=0o750)
135
+ create_directory(output_path)
134
136
  self.compare.compare_converted_dir(my_dump_dir, golden_dump_dir, output_path)
135
137
 
136
138
  @catch_exception
@@ -28,7 +28,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
28
28
  from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
29
29
  check_path_executable, check_path_owner_consistent
30
30
  from msprobe.core.common.const import FileCheckConst
31
- from msprobe.core.common.file_utils import FileChecker, check_file_or_directory_path, remove_path
31
+ from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
32
32
  from msprobe.pytorch.common.log import logger
33
33
 
34
34
 
@@ -71,31 +71,21 @@ class Util:
71
71
  check_path_executable(path)
72
72
 
73
73
  @staticmethod
74
- def get_subdir_count(self, directory):
74
+ def get_subdir_count(directory):
75
75
  subdir_count = 0
76
- path_checker = FileChecker(directory)
77
- path_checker.common_check()
76
+ check_file_or_directory_path(directory, isdir=True)
78
77
  for _, dirs, _ in os.walk(directory):
79
78
  subdir_count += len(dirs)
80
79
  break
81
80
  return subdir_count
82
81
 
83
82
  @staticmethod
84
- def get_subfiles_count(self, directory):
85
- file_count = 0
86
- for root, _, files in os.walk(directory, topdown=True):
87
- path_checker = FileChecker(root)
88
- path_checker.common_check()
89
- file_count += len(files)
90
- path_depth = root.count(os.sep)
91
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
92
- yield root, _, files
93
- else:
94
- _[:] = []
95
- return file_count
83
+ def get_subfiles_count(directory):
84
+ files = os_walk_for_files(directory, Const.MAX_TRAVERSAL_DEPTH)
85
+ return len(files)
96
86
 
97
87
  @staticmethod
98
- def get_sorted_subdirectories_names(self, directory):
88
+ def get_sorted_subdirectories_names(directory):
99
89
  subdirectories = []
100
90
  for item in os.listdir(directory):
101
91
  item_path = os.path.join(directory, item)
@@ -104,7 +94,7 @@ class Util:
104
94
  return sorted(subdirectories)
105
95
 
106
96
  @staticmethod
107
- def get_sorted_files_names(self, directory):
97
+ def get_sorted_files_names(directory):
108
98
  files = []
109
99
  for item in os.listdir(directory):
110
100
  item_path = os.path.join(directory, item)
@@ -113,7 +103,7 @@ class Util:
113
103
  return sorted(files)
114
104
 
115
105
  @staticmethod
116
- def check_npy_files_valid_in_dir(self, dir_path):
106
+ def check_npy_files_valid_in_dir(dir_path):
117
107
  for file_name in os.listdir(dir_path):
118
108
  file_path = os.path.join(dir_path, file_name)
119
109
  check_file_or_directory_path(file_path)
@@ -123,18 +113,18 @@ class Util:
123
113
  return True
124
114
 
125
115
  @staticmethod
126
- def get_md5_for_numpy(self, obj):
116
+ def get_md5_for_numpy(obj):
127
117
  np_bytes = obj.tobytes()
128
118
  md5_hash = hashlib.md5(np_bytes)
129
119
  return md5_hash.hexdigest()
130
120
 
131
121
  @staticmethod
132
- def deal_with_dir_or_file_inconsistency(self, output_path):
122
+ def deal_with_dir_or_file_inconsistency(output_path):
133
123
  remove_path(output_path)
134
124
  raise ParseException("Inconsistent directory structure or file.")
135
125
 
136
126
  @staticmethod
137
- def deal_with_value_if_has_zero(self, data):
127
+ def deal_with_value_if_has_zero(data):
138
128
  if data.dtype in Const.FLOAT_TYPE:
139
129
  zero_mask = (data == 0)
140
130
  # 给0的地方加上eps防止除0
@@ -147,26 +137,19 @@ class Util:
147
137
  return data
148
138
 
149
139
  @staticmethod
150
- def dir_contains_only(self, path, endfix):
151
- for root, _, files in os.walk(path, topdown=True):
152
- path_checker = FileChecker(root)
153
- path_checker.common_check()
154
- for file in files:
155
- if not file.endswith(endfix):
156
- return False
157
- path_depth = root.count(os.sep)
158
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
159
- yield root, _, files
160
- else:
161
- _[:] = []
140
+ def dir_contains_only(path, endfix):
141
+ files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
142
+ for file in files:
143
+ if not file['file'].endswith(endfix):
144
+ return False
162
145
  return True
163
146
 
164
147
  @staticmethod
165
- def localtime_str(self):
148
+ def localtime_str():
166
149
  return time.strftime("%Y%m%d%H%M%S", time.localtime())
167
150
 
168
151
  @staticmethod
169
- def change_filemode_safe(self, path):
152
+ def change_filemode_safe(path):
170
153
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
171
154
 
172
155
  @staticmethod
@@ -183,7 +166,7 @@ class Util:
183
166
  if not cmd:
184
167
  self.log.error("Commond is None")
185
168
  return -1
186
- self.log.info("[RUN CMD]: %s", cmd)
169
+ self.log.info("[RUN CMD]: %s" % cmd)
187
170
  cmd = cmd.split(" ")
188
171
  complete_process = subprocess.run(cmd, shell=False)
189
172
  return complete_process.returncode
@@ -205,7 +188,7 @@ class Util:
205
188
  result = subprocess.run(
206
189
  [self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False)
207
190
  if result.returncode == 0:
208
- self.log.info("Check [%s] success.", target_file)
191
+ self.log.info("Check [%s] success." % (target_file))
209
192
  else:
210
193
  self.log.error("Check msaccucmp failed in dir %s" % target_file)
211
194
  self.log.error("Please specify a valid msaccucmp.py path or install the cann package")
@@ -244,8 +227,11 @@ class Util:
244
227
 
245
228
  def check_path_valid(self, path):
246
229
  path = self.path_strip(path)
247
- path_checker = FileChecker(path)
248
- path_checker.common_check()
230
+ if not path or not os.path.exists(path):
231
+ self.log.error("The path %s does not exist." % path)
232
+ raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
233
+ isdir = check_file_type(path) == FileCheckConst.DIR
234
+ check_file_or_directory_path(path, isdir=isdir)
249
235
  return True
250
236
 
251
237
  def check_files_in_path(self, path):
@@ -273,21 +259,15 @@ class Util:
273
259
  self.check_path_valid(path)
274
260
  file_list = {}
275
261
  re_pattern = re.compile(pattern)
276
- for dir_path, _, file_names in os.walk(path, topdown=True):
277
- path_checker = FileChecker(dir)
278
- path_checker.common_check()
279
- for name in file_names:
280
- match = re_pattern.match(name)
281
- if not match:
282
- continue
283
- if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
284
- continue
285
- file_list[name] = gen_info_func(name, match, dir_path)
286
- path_depth = dir_path.count(os.sep)
287
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
288
- yield dir_path, _, file_names
289
- else:
290
- _[:] = []
262
+ files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
263
+ for file in files:
264
+ name = file["file"]
265
+ match = re_pattern.match(name)
266
+ if not match:
267
+ continue
268
+ if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
269
+ continue
270
+ file_list[name] = gen_info_func(name, match, file["root"])
291
271
  return file_list
292
272
 
293
273
  def check_file_path_format(self, path, suffix):
@@ -314,3 +294,8 @@ class Util:
314
294
  dir1_count = self.get_subdir_count(dir1)
315
295
  dir2_count = self.get_subdir_count(dir2)
316
296
  return dir1_count == dir2_count
297
+
298
+ def check_positive(self, value):
299
+ if value <= 0.0:
300
+ self.log.error("Invalid value. It must be greater than 0.")
301
+ raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
@@ -28,7 +28,7 @@ class Visualization:
28
28
  self.util = Util()
29
29
 
30
30
  def print_npy_summary(self, target_file):
31
- np_data = load_npy(target_file, enable_pickle=True)
31
+ np_data = load_npy(target_file)
32
32
  table = self.util.create_table('', ['Index', 'Data'])
33
33
  flatten_data = np_data.flatten()
34
34
  tablesize = 8
@@ -65,6 +65,8 @@ class Visualization:
65
65
  self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
66
66
  self.util.log.warning("Please check the pkl file")
67
67
  raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
68
+ if not isinstance(msg, list) or len(msg) == 0:
69
+ break
68
70
  info_prefix = msg[0]
69
71
  if not info_prefix.startswith(api_name):
70
72
  continue
@@ -1,12 +1,35 @@
1
- import json
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import os
17
+ import re
3
18
 
4
- from msprobe.core.common_config import CommonConfig, BaseConfig
5
- from msprobe.core.common.file_utils import FileOpen
6
19
  from msprobe.core.common.const import Const
7
- from msprobe.pytorch.hook_module.utils import get_ops
20
+ from msprobe.core.common.exceptions import MsprobeException
21
+ from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
22
+ from msprobe.core.common.log import logger
23
+ from msprobe.core.common.utils import is_int
24
+ from msprobe.core.common_config import BaseConfig, CommonConfig
8
25
  from msprobe.core.grad_probe.constant import level_adp
9
- from msprobe.core.grad_probe.utils import check_numeral_list_ascend
26
+ from msprobe.core.grad_probe.utils import check_bounds
27
+ from msprobe.pytorch.free_benchmark.common.enums import (
28
+ DeviceType,
29
+ HandlerType,
30
+ PytorchFreeBenchmarkConst,
31
+ )
32
+ from msprobe.pytorch.hook_module.utils import get_ops
10
33
 
11
34
 
12
35
  class TensorConfig(BaseConfig):
@@ -16,23 +39,39 @@ class TensorConfig(BaseConfig):
16
39
  self.nfs_path = json_config.get("nfs_path", "")
17
40
  self.host = json_config.get("host", "")
18
41
  self.port = json_config.get("port", -1)
19
- self.tls_path = json_config.get("tls_path", "")
42
+ self.tls_path = json_config.get("tls_path", "./")
43
+ self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
20
44
  self.check_config()
21
45
  self._check_file_format()
22
- self._check_tls_path_config()
46
+ if self.online_run_ut:
47
+ self._check_online_run_ut()
23
48
 
24
49
  def _check_file_format(self):
25
50
  if self.file_format is not None and self.file_format not in ["npy", "bin"]:
26
51
  raise Exception("file_format is invalid")
27
52
 
28
- def _check_tls_path_config(self):
53
+ def _check_online_run_ut(self):
54
+ if not isinstance(self.online_run_ut, bool):
55
+ raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
56
+
57
+ if not isinstance(self.online_run_ut_recompute, bool):
58
+ raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
59
+
60
+ if self.nfs_path:
61
+ check_file_or_directory_path(self.nfs_path, isdir=True)
62
+ return
63
+
29
64
  if self.tls_path:
30
- if not os.path.exists(self.tls_path):
31
- raise Exception("tls_path: %s does not exist" % self.tls_path)
32
- if not os.path.exists(os.path.join(self.tls_path, "client.key")):
33
- raise Exception("tls_path does not contain client.key")
34
- if not os.path.exists(os.path.join(self.tls_path, "client.crt")):
35
- raise Exception("tls_path does not contain client.crt")
65
+ check_file_or_directory_path(self.tls_path, isdir=True)
66
+ check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
67
+ check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
68
+ check_crt_valid(os.path.join(self.tls_path, "client.crt"))
69
+
70
+ if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
71
+ raise Exception(f"host: {self.host} is invalid.")
72
+
73
+ if not isinstance(self.port, int) or not (0 < self.port <= 65535):
74
+ raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
36
75
 
37
76
 
38
77
  class StatisticsConfig(BaseConfig):
@@ -54,30 +93,149 @@ class OverflowCheckConfig(BaseConfig):
54
93
  self.check_overflow_config()
55
94
 
56
95
  def check_overflow_config(self):
57
- if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
96
+ if self.overflow_nums is not None and not is_int(self.overflow_nums):
58
97
  raise Exception("overflow_num is invalid")
59
98
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
60
99
  raise Exception("check_mode is invalid")
61
100
 
62
101
 
63
102
  class FreeBenchmarkCheckConfig(BaseConfig):
103
+
64
104
  def __init__(self, json_config):
65
105
  super().__init__(json_config)
66
- self.fuzz_device = json_config.get("fuzz_device")
67
- self.pert_mode = json_config.get("pert_mode")
68
- self.handler_type = json_config.get("handler_type")
69
- self.fuzz_level = json_config.get("fuzz_level")
70
- self.fuzz_stage = json_config.get("fuzz_stage")
71
- self.if_preheat = json_config.get("if_preheat")
72
- self.preheat_step = json_config.get("preheat_step")
73
- self.max_sample = json_config.get("max_sample")
106
+ self.fuzz_device = json_config.get("fuzz_device", PytorchFreeBenchmarkConst.DEFAULT_DEVICE)
107
+ self.pert_mode = json_config.get("pert_mode", PytorchFreeBenchmarkConst.DEFAULT_MODE)
108
+ self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
109
+ self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
110
+ self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
111
+ self.if_preheat = json_config.get("if_preheat", False)
112
+ self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
113
+ self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
74
114
  self.check_freebenchmark_config()
75
115
 
76
116
  def check_freebenchmark_config(self):
77
- if self.if_preheat and self.handler_type == "fix":
78
- raise Exception("Preheating is not supported in fix handler type")
79
- if self.preheat_step and self.preheat_step == 0:
80
- raise Exception("preheat_step cannot be 0")
117
+ self._check_pert_mode()
118
+ self._check_fuzz_device()
119
+ self._check_handler_type()
120
+ self._check_fuzz_stage()
121
+ self._check_fuzz_level()
122
+ self._check_if_preheat()
123
+ if self.handler_type == HandlerType.FIX:
124
+ self._check_fix_config()
125
+ if self.if_preheat:
126
+ self._check_preheat_config()
127
+
128
+ def _check_pert_mode(self):
129
+ if self.pert_mode not in PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST:
130
+ msg = (
131
+ f"pert_mode is invalid, it should be one of"
132
+ f" {PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST}"
133
+ )
134
+ logger.error_log_with_exp(
135
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
136
+ )
137
+
138
+ def _check_fuzz_device(self):
139
+ if self.fuzz_device not in PytorchFreeBenchmarkConst.DEVICE_LIST:
140
+ msg = (
141
+ f"fuzz_device is invalid, it should be one of"
142
+ f" {PytorchFreeBenchmarkConst.DEVICE_LIST}"
143
+ )
144
+ logger.error_log_with_exp(
145
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
146
+ )
147
+ if (self.fuzz_device == DeviceType.CPU) ^ (
148
+ self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
149
+ ):
150
+ msg = (
151
+ f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
152
+ f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
153
+ )
154
+ logger.error_log_with_exp(
155
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
156
+ )
157
+
158
+ def _check_handler_type(self):
159
+ if self.handler_type not in PytorchFreeBenchmarkConst.HANDLER_LIST:
160
+ msg = (
161
+ f"handler_type is invalid, it should be one of"
162
+ f" {PytorchFreeBenchmarkConst.HANDLER_LIST}"
163
+ )
164
+ logger.error_log_with_exp(
165
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
166
+ )
167
+
168
+ def _check_fuzz_stage(self):
169
+ if self.fuzz_stage not in PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST:
170
+ msg = (
171
+ f"fuzz_stage is invalid, it should be one of"
172
+ f" {PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST}"
173
+ )
174
+ logger.error_log_with_exp(
175
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
176
+ )
177
+
178
+ def _check_fuzz_level(self):
179
+ if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
180
+ msg = (
181
+ f"fuzz_level is invalid, it should be one of"
182
+ f" {PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST}"
183
+ )
184
+ logger.error_log_with_exp(
185
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
186
+ )
187
+
188
+ def _check_if_preheat(self):
189
+ if not isinstance(self.if_preheat, bool):
190
+ msg = "if_preheat is invalid, it should be a boolean"
191
+ logger.error_log_with_exp(
192
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
193
+ )
194
+
195
+ def _check_preheat_config(self):
196
+ if not is_int(self.preheat_step):
197
+ msg = "preheat_step is invalid, it should be an integer"
198
+ logger.error_log_with_exp(
199
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
200
+ )
201
+ if self.preheat_step <= 0:
202
+ msg = "preheat_step must be greater than 0"
203
+ logger.error_log_with_exp(
204
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
205
+ )
206
+ if not is_int(self.max_sample):
207
+ msg = "max_sample is invalid, it should be an integer"
208
+ logger.error_log_with_exp(
209
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
210
+ )
211
+ if self.max_sample <= 0:
212
+ msg = "max_sample must be greater than 0"
213
+ logger.error_log_with_exp(
214
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
215
+ )
216
+
217
+ def _check_fix_config(self):
218
+ if self.if_preheat:
219
+ msg = f"Preheating is not supported for {HandlerType.FIX} handler type"
220
+ logger.error_log_with_exp(
221
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
222
+ )
223
+ if self.fuzz_stage not in PytorchFreeBenchmarkConst.FIX_STAGE_LIST:
224
+ msg = (
225
+ f"The fuzz_stage when opening {HandlerType.FIX} handler must be one of "
226
+ f"{PytorchFreeBenchmarkConst.FIX_STAGE_LIST}"
227
+ )
228
+ logger.error_log_with_exp(
229
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
230
+ )
231
+ if self.pert_mode not in PytorchFreeBenchmarkConst.FIX_MODE_LIST:
232
+ msg = (
233
+ f"The pert_mode when opening {HandlerType.FIX} handler must be one of "
234
+ f"{PytorchFreeBenchmarkConst.FIX_MODE_LIST}"
235
+ )
236
+ logger.error_log_with_exp(
237
+ msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
238
+ )
81
239
 
82
240
 
83
241
  class RunUTConfig(BaseConfig):
@@ -93,7 +251,7 @@ class RunUTConfig(BaseConfig):
93
251
  self.host = json_config.get("host", "")
94
252
  self.port = json_config.get("port", -1)
95
253
  self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
96
- self.tls_path = json_config.get("tls_path", "")
254
+ self.tls_path = json_config.get("tls_path", "./")
97
255
  self.check_run_ut_config()
98
256
 
99
257
  @classmethod
@@ -118,13 +276,8 @@ class RunUTConfig(BaseConfig):
118
276
 
119
277
  @classmethod
120
278
  def check_tls_path_config(cls, tls_path):
121
- if tls_path:
122
- if not os.path.exists(tls_path):
123
- raise Exception("tls_path: %s does not exist" % tls_path)
124
- if not os.path.exists(os.path.join(tls_path, "server.key")):
125
- raise Exception("tls_path does not contain server.key")
126
- if not os.path.exists(os.path.join(tls_path, "server.crt")):
127
- raise Exception("tls_path does not contain server.crt")
279
+ if tls_path and not os.path.exists(tls_path):
280
+ raise Exception("tls_path: %s does not exist" % tls_path)
128
281
 
129
282
  def check_run_ut_config(self):
130
283
  RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -141,13 +294,13 @@ class GradToolConfig(BaseConfig):
141
294
  self.param_list = json_config.get("param_list", [])
142
295
  self.bounds = json_config.get("bounds", [-1, 0, 1])
143
296
  self._check_config()
144
-
297
+
145
298
  def _check_config(self):
146
299
  if self.grad_level not in level_adp.keys():
147
300
  raise Exception(f"grad_level must be one of {level_adp.keys()}")
148
301
  if not isinstance(self.param_list, list):
149
302
  raise Exception(f"param_list must be a list")
150
- check_numeral_list_ascend(self.bounds)
303
+ check_bounds(self.bounds)
151
304
 
152
305
 
153
306
  def parse_task_config(task, json_config):
@@ -178,10 +331,9 @@ def parse_json_config(json_file_path, task):
178
331
  if not json_file_path:
179
332
  config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
180
333
  json_file_path = os.path.join(config_dir, "config.json")
181
- with FileOpen(json_file_path, 'r') as file:
182
- json_config = json.load(file)
334
+ json_config = load_json(json_file_path)
183
335
  common_config = CommonConfig(json_config)
184
- if task and task in Const.TASK_LIST:
336
+ if task:
185
337
  task_config = parse_task_config(task, json_config)
186
338
  else:
187
339
  task_config = parse_task_config(common_config.task, json_config)