mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -22,6 +22,8 @@ from collections import namedtuple
22
22
  from msprobe.pytorch.parse_tool.lib.utils import Util
23
23
  from msprobe.pytorch.parse_tool.lib.config import Const
24
24
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
25
+ from msprobe.core.common.utils import create_directory, write_csv, save_npy_to_txt
26
+ from msprobe.core.common.file_check import FileChecker
25
27
 
26
28
 
27
29
  class Compare:
@@ -36,7 +38,7 @@ class Compare:
36
38
  self.log.info("Compare finished!!")
37
39
 
38
40
  def compare_vector(self, my_dump_path, golden_dump_path, result_dir, msaccucmp_path):
39
- self.util.create_dir(result_dir)
41
+ create_directory(result_dir)
40
42
  self.util.check_path_valid(result_dir)
41
43
  call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path)
42
44
  cmd = '%s %s compare -m %s -g %s -out %s' % (
@@ -65,7 +67,7 @@ class Compare:
65
67
  self.util.print_panel("\n".join(summary_txt))
66
68
 
67
69
  def convert(self, dump_file, data_format, output, msaccucmp_path):
68
- self.util.create_dir(output)
70
+ create_directory(output)
69
71
  self.util.check_path_valid(output)
70
72
  call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path)
71
73
  if data_format:
@@ -83,21 +85,22 @@ class Compare:
83
85
  (left, right, save_txt, rl, al, diff_count) = args
84
86
  if left is None or right is None:
85
87
  raise ParseException("invalid input or output")
86
- try:
87
- left_data = np.load(left)
88
- right_data = np.load(right)
89
- except UnicodeError as e:
90
- self.log.error("%s %s" % ("UnicodeError", str(e)))
91
- self.log.warning("Please check the npy file")
92
- raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
93
- except IOError:
94
- self.log.error("Failed to load npy %s or %s." % (left, right))
95
- raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e
88
+ if self.util.check_path_valid(left) and self.util.check_path_valid(right):
89
+ try:
90
+ left_data = np.load(left)
91
+ right_data = np.load(right)
92
+ except UnicodeError as e:
93
+ self.log.error("%s %s" % ("UnicodeError", str(e)))
94
+ self.log.warning("Please check the npy file")
95
+ raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
96
+ except IOError:
97
+ self.log.error("Failed to load npy %s or %s." % (left, right))
98
+ raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e
96
99
 
97
100
  # save to txt
98
101
  if save_txt:
99
- self.util.save_npy_to_txt(left_data, left + ".txt")
100
- self.util.save_npy_to_txt(right_data, right + ".txt")
102
+ save_npy_to_txt(left_data, left + ".txt")
103
+ save_npy_to_txt(right_data, right + ".txt")
101
104
  # compare data
102
105
  (total_cnt, all_close, cos_sim, err_percent) = self.do_compare_data(left_data, right_data, rl, al, diff_count)
103
106
  content = ['Left:', ' ├─ NpyFile: %s' % left]
@@ -157,8 +160,10 @@ class Compare:
157
160
  return res
158
161
 
159
162
  def compare_npy(self, file, bench_file, output_path):
160
- data = np.load(file)
161
- bench_data = np.load(bench_file)
163
+ if self.util.check_path_valid(file):
164
+ data = np.load(file)
165
+ if self.util.check_path_valid(bench_file):
166
+ bench_data = np.load(bench_file)
162
167
  shape, dtype = data.shape, data.dtype
163
168
  bench_shape, bench_dtype = bench_data.shape, bench_data.dtype
164
169
  filename = os.path.basename(file)
@@ -181,7 +186,7 @@ class Compare:
181
186
  rel_diff_max = np.max(rel_error)
182
187
  compare_result = [[filename, bench_filename, data_mean, bench_data_mean, md5_consistency, abs_diff_max,
183
188
  rel_diff_max]]
184
- self.util.write_csv(compare_result, output_path)
189
+ write_csv(compare_result, output_path)
185
190
 
186
191
  def compare_all_file_in_directory(self, my_dump_dir, golden_dump_dir, output_path):
187
192
  if not (self.util.is_subdir_count_equal(my_dump_dir, golden_dump_dir)
@@ -228,7 +233,7 @@ class Compare:
228
233
  "Max Abs Error",
229
234
  "Max Relative Error"
230
235
  ]]
231
- self.util.write_csv(title_rows, output_path)
236
+ write_csv(title_rows, output_path)
232
237
 
233
238
  my_ordered_subdirs = self.util.get_sorted_subdirectories_names(my_dump_dir)
234
239
  golden_ordered_subdirs = self.util.get_sorted_subdirectories_names(golden_dump_dir)
@@ -246,7 +251,9 @@ class Compare:
246
251
 
247
252
  def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
248
253
  dump_dir = self.util.path_strip(dump_dir)
249
- for root, _, files in os.walk(dump_dir):
254
+ for root, _, files in os.walk(dump_dir, topdown=True):
255
+ path_checker = FileChecker(root)
256
+ path_checker.common_check()
250
257
  for file in files:
251
258
  file_path = os.path.join(root, file)
252
259
  file_name = os.path.basename(file_path)
@@ -257,3 +264,8 @@ class Compare:
257
264
  timestamp = parts[-1]
258
265
  output_path = os.path.join(output_dir, op_name, timestamp)
259
266
  self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path)
267
+ path_depth = root.count(os.sep)
268
+ if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
269
+ yield root, _, files
270
+ else:
271
+ _[:] = []
@@ -33,11 +33,12 @@ class Const:
33
33
  OFFLINE_DUMP_CONVERT_PATTERN = \
34
34
  r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
35
35
  r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
36
- NUMPY_PATTERN = r".*\.npy$"
36
+ NUMPY_PATTERN = r"^[\w\-_-]\.npy$"
37
37
  NPY_SUFFIX = ".npy"
38
38
  PKL_SUFFIX = ".pkl"
39
39
  DIRECTORY_LENGTH = 4096
40
40
  FILE_NAME_LENGTH = 255
41
+ MAX_TRAVERSAL_DEPTH = 5
41
42
  FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
42
43
  ONE_GB = 1 * 1024 * 1024 * 1024
43
44
  TEN_GB = 10 * 1024 * 1024 * 1024
@@ -23,7 +23,7 @@ from msprobe.pytorch.parse_tool.lib.utils import Util
23
23
  from msprobe.pytorch.parse_tool.lib.compare import Compare
24
24
  from msprobe.pytorch.parse_tool.lib.visualization import Visualization
25
25
  from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
26
-
26
+ from msprobe.core.common.utils import create_directory
27
27
 
28
28
  class ParseTool:
29
29
  def __init__(self):
@@ -33,7 +33,7 @@ class ParseTool:
33
33
 
34
34
  @catch_exception
35
35
  def prepare(self):
36
- self.util.create_dir(Const.DATA_ROOT_DIR)
36
+ create_directory(Const.DATA_ROOT_DIR)
37
37
 
38
38
  @catch_exception
39
39
  def do_vector_compare(self, args):
@@ -112,8 +112,8 @@ class ParseTool:
112
112
  args = parser.parse_args(argv)
113
113
  self.util.check_path_valid(args.my_dump_path)
114
114
  self.util.check_path_valid(args.golden_dump_path)
115
- self.util.check_path_format(args.my_dump_path, Const.NPY_SUFFIX)
116
- self.util.check_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
115
+ self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
116
+ self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
117
117
  compare_data_args = namedtuple('compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
118
118
  compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)
119
119
  res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count)
@@ -31,7 +31,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
31
31
  from msprobe.core.common.file_check import change_mode, check_other_user_writable,\
32
32
  check_path_executable, check_path_owner_consistent
33
33
  from msprobe.core.common.const import FileCheckConst
34
- from msprobe.core.common.file_check import FileOpen
34
+ from msprobe.core.common.file_check import FileOpen, FileChecker
35
35
  from msprobe.core.common.utils import check_file_or_directory_path
36
36
  from msprobe.pytorch.common.log import logger
37
37
 
@@ -57,12 +57,7 @@ except ImportError as err:
57
57
  class Util:
58
58
  def __init__(self):
59
59
  self.ms_accu_cmp = None
60
- logging.basicConfig(
61
- level=Const.LOG_LEVEL,
62
- format="%(asctime)s (%(process)d) -[%(levelname)s]%(message)s",
63
- datefmt="%Y-%m-%d %H:%M:%S"
64
- )
65
- self.log = logging.getLogger()
60
+ self.log = logger
66
61
  self.python = sys.executable
67
62
 
68
63
  @staticmethod
@@ -82,6 +77,8 @@ class Util:
82
77
  @staticmethod
83
78
  def get_subdir_count(self, directory):
84
79
  subdir_count = 0
80
+ path_checker = FileChecker(directory)
81
+ path_checker.common_check()
85
82
  for _, dirs, _ in os.walk(directory):
86
83
  subdir_count += len(dirs)
87
84
  break
@@ -90,8 +87,15 @@ class Util:
90
87
  @staticmethod
91
88
  def get_subfiles_count(self, directory):
92
89
  file_count = 0
93
- for _, _, files in os.walk(directory):
90
+ for root, _, files in os.walk(directory, topdown=True):
91
+ path_checker = FileChecker(root)
92
+ path_checker.common_check()
94
93
  file_count += len(files)
94
+ path_depth = root.count(os.sep)
95
+ if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
96
+ yield root, _, files
97
+ else:
98
+ _[:] = []
95
99
  return file_count
96
100
 
97
101
  @staticmethod
@@ -128,17 +132,6 @@ class Util:
128
132
  md5_hash = hashlib.md5(np_bytes)
129
133
  return md5_hash.hexdigest()
130
134
 
131
- @staticmethod
132
- def write_csv(self, data, filepath):
133
- need_change_mode = False
134
- if not os.path.exists(filepath):
135
- need_change_mode = True
136
- with FileOpen(filepath, 'a') as f:
137
- writer = csv.writer(f)
138
- writer.writerows(data)
139
- if need_change_mode:
140
- change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
141
-
142
135
  @staticmethod
143
136
  def deal_with_dir_or_file_inconsistency(self, output_path):
144
137
  if os.path.exists(output_path):
@@ -160,10 +153,17 @@ class Util:
160
153
 
161
154
  @staticmethod
162
155
  def dir_contains_only(self, path, endfix):
163
- for _, _, files in os.walk(path):
156
+ for root, _, files in os.walk(path, topdown=True):
157
+ path_checker = FileChecker(root)
158
+ path_checker.common_check()
164
159
  for file in files:
165
160
  if not file.endswith(endfix):
166
161
  return False
162
+ path_depth = root.count(os.sep)
163
+ if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
164
+ yield root, _, files
165
+ else:
166
+ _[:] = []
167
167
  return True
168
168
 
169
169
  @staticmethod
@@ -188,7 +188,7 @@ class Util:
188
188
  if not cmd:
189
189
  self.log.error("Commond is None")
190
190
  return -1
191
- self.log.debug("[RUN CMD]: %s", cmd)
191
+ self.log.info("[RUN CMD]: %s", cmd)
192
192
  cmd = cmd.split(" ")
193
193
  complete_process = subprocess.run(cmd, shell=False)
194
194
  return complete_process.returncode
@@ -208,7 +208,7 @@ class Util:
208
208
  "Check msaccucmp failed in dir %s. This is not a correct msaccucmp file" % target_file)
209
209
  raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
210
210
  result = subprocess.run(
211
- [self.python, target_file, "--help"], stdout=subprocess.PIPE)
211
+ [self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False)
212
212
  if result.returncode == 0:
213
213
  self.log.info("Check [%s] success.", target_file)
214
214
  else:
@@ -217,37 +217,12 @@ class Util:
217
217
  raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
218
218
  return target_file
219
219
 
220
- def create_dir(self, path):
221
- path = self.path_strip(path)
222
- if os.path.exists(path):
223
- return
224
- self.check_path_name(path)
225
- try:
226
- os.makedirs(path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
227
- except OSError as e:
228
- self.log.error("Failed to create %s.", path)
229
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e
230
-
231
220
  def gen_npy_info_txt(self, source_data):
232
221
  (shape, dtype, max_data, min_data, mean) = \
233
222
  self.npy_info(source_data)
234
223
  return \
235
224
  '[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]' % (shape, dtype, max_data, min_data, mean)
236
225
 
237
- def save_npy_to_txt(self, data, dst_file='', align=0):
238
- if os.path.exists(dst_file):
239
- self.log.info("Dst file %s exists, will not save new one.", dst_file)
240
- return
241
- shape = data.shape
242
- data = data.flatten()
243
- if align == 0:
244
- align = 1 if len(shape) == 0 else shape[-1]
245
- elif data.size % align != 0:
246
- pad_array = np.zeros((align - data.size % align,))
247
- data = np.append(data, pad_array)
248
- np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
249
- change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
250
-
251
226
  def list_convert_files(self, path, external_pattern=""):
252
227
  return self.list_file_with_pattern(
253
228
  path, Const.OFFLINE_DUMP_CONVERT_PATTERN, external_pattern, self._gen_npu_dump_convert_file_info
@@ -274,27 +249,8 @@ class Util:
274
249
 
275
250
  def check_path_valid(self, path):
276
251
  path = self.path_strip(path)
277
- if not path or not os.path.exists(path):
278
- self.log.error("The path %s does not exist." % path)
279
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
280
- if os.path.islink(path):
281
- self.log.error('The file path {} is a soft link.'.format(path))
282
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
283
- if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
284
- Const.FILE_NAME_LENGTH:
285
- self.log.error('The file path length exceeds limit.')
286
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
287
- if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
288
- self.log.error('The file path {} contains special characters.'.format(path))
289
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
290
- if os.path.isfile(path):
291
- file_size = os.path.getsize(path)
292
- if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
293
- self.log.error('The file {} size is greater than 1GB.'.format(path))
294
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
295
- if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB:
296
- self.log.error('The file {} size is greater than 10GB.'.format(path))
297
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
252
+ path_checker = FileChecker(path)
253
+ path_checker.common_check()
298
254
  return True
299
255
 
300
256
  def check_files_in_path(self, path):
@@ -322,17 +278,24 @@ class Util:
322
278
  self.check_path_valid(path)
323
279
  file_list = {}
324
280
  re_pattern = re.compile(pattern)
325
- for dir_path, _, file_names in os.walk(path, followlinks=True):
281
+ for dir_path, _, file_names in os.walk(path, topdown=True):
282
+ path_checker = FileChecker(dir)
283
+ path_checker.common_check()
326
284
  for name in file_names:
327
285
  match = re_pattern.match(name)
328
286
  if not match:
329
287
  continue
330
- if extern_pattern != '' and not re.match(extern_pattern, name):
288
+ if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
331
289
  continue
332
290
  file_list[name] = gen_info_func(name, match, dir_path)
291
+ path_depth = dir_path.count(os.sep)
292
+ if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
293
+ yield dir_path, _, file_names
294
+ else:
295
+ _[:] = []
333
296
  return file_list
334
297
 
335
- def check_path_format(self, path, suffix):
298
+ def check_file_path_format(self, path, suffix):
336
299
  if os.path.isfile(path):
337
300
  if not path.endswith(suffix):
338
301
  self.log.error("%s is not a %s file." % (path, suffix))
@@ -344,15 +307,6 @@ class Util:
344
307
  self.log.error("The file path %s is invalid" % path)
345
308
  raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
346
309
 
347
- def check_path_name(self, path):
348
- if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
349
- Const.FILE_NAME_LENGTH:
350
- self.log.error('The file path length exceeds limit.')
351
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
352
- if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
353
- self.log.error('The file path {} contains special characters.'.format(path))
354
- raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
355
-
356
310
  def check_str_param(self, param):
357
311
  if len(param) > Const.FILE_NAME_LENGTH:
358
312
  self.log.error('The parameter length exceeds limit')
@@ -21,6 +21,7 @@ from msprobe.pytorch.parse_tool.lib.config import Const
21
21
  from msprobe.pytorch.parse_tool.lib.utils import Util
22
22
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
23
23
  from msprobe.core.common.file_check import FileOpen
24
+ from msprobe.core.common.utils import save_npy_to_txt
24
25
 
25
26
 
26
27
  class Visualization:
@@ -43,18 +44,18 @@ class Visualization:
43
44
  summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file,
44
45
  "TextFile: %s.txt" % target_file]
45
46
  self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file)
46
- self.util.save_npy_to_txt(np_data, target_file + ".txt")
47
+ save_npy_to_txt(np_data, target_file + ".txt")
47
48
 
48
49
  def print_npy_data(self, file_name):
49
50
  file_name = self.util.path_strip(file_name)
50
51
  self.util.check_path_valid(file_name)
51
- self.util.check_path_format(file_name, Const.NPY_SUFFIX)
52
+ self.util.check_file_path_format(file_name, Const.NPY_SUFFIX)
52
53
  return self.print_npy_summary(file_name)
53
54
 
54
55
  def parse_pkl(self, path, api_name):
55
56
  path = self.util.path_strip(path)
56
57
  self.util.check_path_valid(path)
57
- self.util.check_path_format(path, Const.PKL_SUFFIX)
58
+ self.util.check_file_path_format(path, Const.PKL_SUFFIX)
58
59
  self.util.check_str_param(api_name)
59
60
  with FileOpen(path, "r") as pkl_handle:
60
61
  title_printed = False
@@ -4,18 +4,36 @@ import os
4
4
  from msprobe.core.common_config import CommonConfig, BaseConfig
5
5
  from msprobe.core.common.file_check import FileOpen
6
6
  from msprobe.core.common.const import Const
7
+ from msprobe.pytorch.hook_module.utils import get_ops
8
+ from msprobe.core.grad_probe.constant import level_adp
9
+ from msprobe.core.grad_probe.utils import check_numeral_list_ascend
7
10
 
8
11
 
9
12
  class TensorConfig(BaseConfig):
10
13
  def __init__(self, json_config):
11
14
  super().__init__(json_config)
15
+ self.online_run_ut = json_config.get("online_run_ut", False)
16
+ self.nfs_path = json_config.get("nfs_path", "")
17
+ self.host = json_config.get("host", "")
18
+ self.port = json_config.get("port", -1)
19
+ self.tls_path = json_config.get("tls_path", "")
12
20
  self.check_config()
13
21
  self._check_file_format()
22
+ self._check_tls_path_config()
14
23
 
15
24
  def _check_file_format(self):
16
25
  if self.file_format is not None and self.file_format not in ["npy", "bin"]:
17
26
  raise Exception("file_format is invalid")
18
27
 
28
+ def _check_tls_path_config(self):
29
+ if self.tls_path:
30
+ if not os.path.exists(self.tls_path):
31
+ raise Exception("tls_path: %s does not exist" % self.tls_path)
32
+ if not os.path.exists(os.path.join(self.tls_path, "client.key")):
33
+ raise Exception("tls_path does not contain client.key")
34
+ if not os.path.exists(os.path.join(self.tls_path, "client.crt")):
35
+ raise Exception("tls_path does not contain client.crt")
36
+
19
37
 
20
38
  class StatisticsConfig(BaseConfig):
21
39
  def __init__(self, json_config):
@@ -31,12 +49,12 @@ class StatisticsConfig(BaseConfig):
31
49
  class OverflowCheckConfig(BaseConfig):
32
50
  def __init__(self, json_config):
33
51
  super().__init__(json_config)
34
- self.overflow_num = json_config.get("overflow_nums")
52
+ self.overflow_nums = json_config.get("overflow_nums")
35
53
  self.check_mode = json_config.get("check_mode")
36
54
  self.check_overflow_config()
37
55
 
38
56
  def check_overflow_config(self):
39
- if self.overflow_num is not None and not isinstance(self.overflow_num, int):
57
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
40
58
  raise Exception("overflow_num is invalid")
41
59
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
42
60
  raise Exception("check_mode is invalid")
@@ -61,20 +79,96 @@ class FreeBenchmarkCheckConfig(BaseConfig):
61
79
  if self.preheat_step and self.preheat_step == 0:
62
80
  raise Exception("preheat_step cannot be 0")
63
81
 
82
+
83
class RunUTConfig(BaseConfig):
    """Configuration object built for the ``Const.RUN_UT`` task.

    Reads the API filter lists, the error-data directory and the online-mode
    options (``is_online``, ``nfs_path``, ``host``, ``port``, ``rank_list``,
    ``tls_path``) from the task section, then validates the filter lists and
    the three path options during construction.
    """

    # Evaluated once, when the class body is executed. Presumably the
    # collection of wrapped API names accepted in the filter lists --
    # confirm against get_ops.
    WrapApi = get_ops()

    def __init__(self, json_config):
        super().__init__(json_config)
        self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
        self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
        self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
        self.is_online = json_config.get("is_online", False)
        self.nfs_path = json_config.get("nfs_path", "")
        self.host = json_config.get("host", "")
        self.port = json_config.get("port", -1)
        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
        self.tls_path = json_config.get("tls_path", "")
        self.check_run_ut_config()

    @classmethod
    def check_filter_list_config(cls, key, filter_list):
        """Validate one API filter list.

        Args:
            key: Option name, used only in the error messages.
            filter_list: Value configured for that option.

        Raises:
            Exception: if the value is not a list, holds a non-string
                element, or names an API that is not in ``WrapApi``.
        """
        if not isinstance(filter_list, list):
            raise Exception("%s must be a list type" % key)
        if not all(isinstance(item, str) for item in filter_list):
            raise Exception("All elements in %s must be string type" % key)
        invalid_api = [item for item in filter_list if item not in cls.WrapApi]
        if invalid_api:
            raise Exception("Invalid api in %s: %s" % (key, invalid_api))

    @classmethod
    def _check_path_type(cls, key, path):
        """Raise unless ``path`` is a string or an ``os.PathLike`` object.

        ``os.path.exists`` interprets integers (and therefore booleans) as
        open file descriptors and raises ``TypeError`` for ``None``, so a
        non-string JSON value must be rejected before it reaches the
        filesystem checks.
        """
        if not isinstance(path, (str, os.PathLike)):
            raise Exception("%s must be a string path" % key)

    @classmethod
    def check_error_data_path_config(cls, error_data_path):
        """Raise if ``error_data_path`` is not a path or does not exist."""
        cls._check_path_type("error_data_path", error_data_path)
        if not os.path.exists(error_data_path):
            raise Exception("error_data_path: %s does not exist" % error_data_path)

    @classmethod
    def check_nfs_path_config(cls, nfs_path):
        """Raise if a configured ``nfs_path`` is not an existing path.

        An empty (falsy) value means the option is unset and is accepted.
        """
        if not nfs_path:
            return
        cls._check_path_type("nfs_path", nfs_path)
        if not os.path.exists(nfs_path):
            raise Exception("nfs_path: %s does not exist" % nfs_path)

    @classmethod
    def check_tls_path_config(cls, tls_path):
        """Raise if a configured ``tls_path`` lacks the server key or cert.

        An empty (falsy) value means the option is unset and is accepted.
        """
        if not tls_path:
            return
        cls._check_path_type("tls_path", tls_path)
        if not os.path.exists(tls_path):
            raise Exception("tls_path: %s does not exist" % tls_path)
        # Key first, certificate second: the first missing file is reported.
        for required_file in ("server.key", "server.crt"):
            if not os.path.exists(os.path.join(tls_path, required_file)):
                raise Exception("tls_path does not contain %s" % required_file)

    def check_run_ut_config(self):
        """Run every run_ut validation against the values read in ``__init__``."""
        RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
        RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
        RunUTConfig.check_error_data_path_config(self.error_data_path)
        RunUTConfig.check_nfs_path_config(self.nfs_path)
        RunUTConfig.check_tls_path_config(self.tls_path)
135
+
136
+
137
class GradToolConfig(BaseConfig):
    """Configuration object built for the ``Const.GRAD_PROBE`` task.

    Reads ``grad_level`` (default "L1"), ``param_list`` (default ``[]``) and
    ``bounds`` (default ``[]``) from the task section and validates them
    during construction.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.grad_level = json_config.get("grad_level", "L1")
        self.param_list = json_config.get("param_list", [])
        self.bounds = json_config.get("bounds", [])
        # The validator was defined but never invoked, so bad values slipped
        # through construction; run it as soon as the fields are set.
        self._check_config()

    def _check_config(self):
        """Validate the three grad-probe options.

        Raises:
            Exception: if ``grad_level`` is not a key of ``level_adp`` or
                ``param_list`` is not a list. ``bounds`` is delegated to
                ``check_numeral_list_ascend`` (presumably it rejects
                non-numeric or non-ascending lists -- confirm in
                ``msprobe.core.grad_probe.utils``).
        """
        if self.grad_level not in level_adp:
            raise Exception(f"grad_level must be one of {level_adp.keys()}")
        if not isinstance(self.param_list, list):
            raise Exception("param_list must be a list")
        check_numeral_list_ascend(self.bounds)
150
+
151
+
64
152
  def parse_task_config(task, json_config):
65
153
  default_dic = {}
66
154
  if task == Const.TENSOR:
67
- config_dic = json_config.get(Const.TENSOR) if json_config.get(Const.TENSOR) else default_dic
155
+ config_dic = json_config.get(Const.TENSOR, default_dic)
68
156
  return TensorConfig(config_dic)
69
157
  elif task == Const.STATISTICS:
70
- config_dic = json_config.get(Const.STATISTICS) if json_config.get(Const.STATISTICS) else default_dic
158
+ config_dic = json_config.get(Const.STATISTICS, default_dic)
71
159
  return StatisticsConfig(config_dic)
72
160
  elif task == Const.OVERFLOW_CHECK:
73
- config_dic = json_config.get(Const.OVERFLOW_CHECK) if json_config.get(Const.OVERFLOW_CHECK) else default_dic
161
+ config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
74
162
  return OverflowCheckConfig(config_dic)
75
163
  elif task == Const.FREE_BENCHMARK:
76
- config_dic = json_config.get(Const.FREE_BENCHMARK) if json_config.get(Const.FREE_BENCHMARK) else default_dic
164
+ config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
77
165
  return FreeBenchmarkCheckConfig(config_dic)
166
+ elif task == Const.RUN_UT:
167
+ config_dic = json_config.get(Const.RUN_UT, default_dic)
168
+ return RunUTConfig(config_dic)
169
+ elif task == Const.GRAD_PROBE:
170
+ config_dic = json_config.get(Const.GRAD_PROBE, default_dic)
171
+ return GradToolConfig(config_dic)
78
172
  else:
79
173
  return StatisticsConfig(default_dic)
80
174