mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,116 +1,96 @@
1
- import os
2
- import csv
3
- import fcntl
4
- import json
5
- from pathlib import Path
6
-
7
- from msprobe.core.common.file_check import change_mode
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import Const, FileCheckConst
10
-
11
-
12
- class DataWriter:
13
-
14
- def __init__(self, init_json=None) -> None:
15
- self.dump_count = 0
16
- self.init_json = init_json
17
- self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
18
- self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
19
- self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
20
- self.free_benchmark_file_path = None
21
- self.dump_tensor_data_dir = None
22
- self.buffer_size = 1000
23
- self.cache_data = {Const.DATA: {}}
24
- self.cache_stack = {}
25
- self.cache_construct = {}
26
-
27
- @staticmethod
28
- def write_data_to_csv(result: list, result_header: tuple, file_path: str):
29
- if not result:
30
- return
31
- is_exists = os.path.exists(file_path)
32
- append = "a+" if is_exists else "w+"
33
- with os.fdopen(
34
- os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
35
- ) as csv_file:
36
- spawn_writer = csv.writer(csv_file)
37
- if not is_exists:
38
- spawn_writer.writerow(result_header)
39
- spawn_writer.writerows([result,])
40
-
41
- def initialize_json_file(self, **kwargs):
42
- kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
43
- with os.fdopen(
44
- os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
45
- ) as f:
46
- json.dump(kwargs, f)
47
-
48
- if os.path.exists(self.stack_file_path):
49
- os.remove(self.stack_file_path)
50
- Path(self.stack_file_path).touch()
51
- change_mode(self.stack_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
52
-
53
- if os.path.exists(self.construct_file_path):
54
- os.remove(self.construct_file_path)
55
- Path(self.construct_file_path).touch()
56
- change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
57
-
58
- def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
59
- free_benchmark_file_path):
60
- self.dump_file_path = dump_file_path
61
- self.stack_file_path = stack_file_path
62
- self.construct_file_path = construct_file_path
63
- self.dump_tensor_data_dir = dump_data_dir
64
- self.free_benchmark_file_path = free_benchmark_file_path
65
-
66
- def update_data(self, new_data):
67
- key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
68
- if key in self.cache_data[Const.DATA]:
69
- self.cache_data[Const.DATA][key].update(new_data[key])
70
- else:
71
- self.cache_data[Const.DATA].update(new_data)
72
-
73
- def flush_data_when_buffer_is_full(self):
74
- if len(self.cache_data[Const.DATA]) >= self.buffer_size:
75
- self.write_data_json(self.dump_file_path)
76
-
77
- def update_stack(self, new_data):
78
- self.cache_stack.update(new_data)
79
-
80
- def update_construct(self, new_data):
81
- self.cache_construct.update(new_data)
82
-
83
- def write_data_json(self, file_path):
84
- logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
85
- if Path(file_path).exists() and os.path.getsize(file_path) > 0:
86
- with open(file_path, "r+") as f:
87
- fcntl.flock(f, fcntl.LOCK_EX)
88
- data_to_write = json.load(f)
89
- fcntl.flock(f, fcntl.LOCK_UN)
90
- else:
91
- self.init_json['data_path'] = self.dump_tensor_data_dir
92
- data_to_write = self.init_json
93
- data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
94
- with open(file_path, 'w+') as f:
95
- fcntl.flock(f, fcntl.LOCK_EX)
96
- json.dump(data_to_write, f, indent=1)
97
- fcntl.flock(f, fcntl.LOCK_UN)
98
-
99
- self.cache_data[Const.DATA].clear()
100
-
101
- def write_stack_info_json(self, file_path):
102
- with open(file_path, 'w+') as f:
103
- fcntl.flock(f, fcntl.LOCK_EX)
104
- json.dump(self.cache_stack, f, indent=1)
105
- fcntl.flock(f, fcntl.LOCK_UN)
106
-
107
- def write_construct_info_json(self, file_path):
108
- with open(file_path, 'w+') as f:
109
- fcntl.flock(f, fcntl.LOCK_EX)
110
- json.dump(self.cache_construct, f, indent=1)
111
- fcntl.flock(f, fcntl.LOCK_UN)
112
-
113
- def write_json(self):
114
- self.write_data_json(self.dump_file_path)
115
- self.write_stack_info_json(self.stack_file_path)
116
- self.write_construct_info_json(self.construct_file_path)
1
+ import os
2
+ import csv
3
+
4
+ from msprobe.core.common.file_utils import change_mode, FileOpen
5
+ from msprobe.core.common.log import logger
6
+ from msprobe.core.common.const import Const, FileCheckConst
7
+ from msprobe.core.common.file_utils import remove_path, load_json, save_json
8
+
9
+
10
+ class DataWriter:
11
+
12
+ def __init__(self, init_json=None) -> None:
13
+ self.dump_count = 0
14
+ self.init_json = init_json
15
+ self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
16
+ self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
17
+ self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
18
+ self.free_benchmark_file_path = None
19
+ self.dump_tensor_data_dir = None
20
+ self.buffer_size = 1000
21
+ self.cache_data = {Const.DATA: {}}
22
+ self.cache_stack = {}
23
+ self.cache_construct = {}
24
+
25
+ @staticmethod
26
+ def write_data_to_csv(result: list, result_header: tuple, file_path: str):
27
+ if not result:
28
+ return
29
+ is_exists = os.path.exists(file_path)
30
+ append = "a+" if is_exists else "w+"
31
+ with FileOpen(file_path, append) as csv_file:
32
+ spawn_writer = csv.writer(csv_file)
33
+ if not is_exists:
34
+ spawn_writer.writerow(result_header)
35
+ spawn_writer.writerows([result,])
36
+ is_new_file = not is_exists
37
+ if is_new_file:
38
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
39
+
40
+ def initialize_json_file(self, **kwargs):
41
+ kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
42
+ save_json(self.dump_file_path, kwargs)
43
+
44
+ empty_dict = {}
45
+ remove_path(self.stack_file_path)
46
+ save_json(self.stack_file_path, empty_dict)
47
+
48
+ remove_path(self.construct_file_path)
49
+ save_json(self.construct_file_path, empty_dict)
50
+
51
+ def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
52
+ free_benchmark_file_path):
53
+ self.dump_file_path = dump_file_path
54
+ self.stack_file_path = stack_file_path
55
+ self.construct_file_path = construct_file_path
56
+ self.dump_tensor_data_dir = dump_data_dir
57
+ self.free_benchmark_file_path = free_benchmark_file_path
58
+
59
+ def update_data(self, new_data):
60
+ key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
61
+ if key in self.cache_data[Const.DATA]:
62
+ self.cache_data[Const.DATA][key].update(new_data[key])
63
+ else:
64
+ self.cache_data[Const.DATA].update(new_data)
65
+
66
+ def flush_data_when_buffer_is_full(self):
67
+ if len(self.cache_data[Const.DATA]) >= self.buffer_size:
68
+ self.write_data_json(self.dump_file_path)
69
+
70
+ def update_stack(self, new_data):
71
+ self.cache_stack.update(new_data)
72
+
73
+ def update_construct(self, new_data):
74
+ self.cache_construct.update(new_data)
75
+
76
+ def write_data_json(self, file_path):
77
+ logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
78
+ if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
79
+ data_to_write = load_json(file_path)
80
+ else:
81
+ self.init_json['data_path'] = self.dump_tensor_data_dir
82
+ data_to_write = self.init_json
83
+ data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
84
+ save_json(file_path, data_to_write, indent=1)
85
+ self.cache_data[Const.DATA].clear()
86
+
87
+ def write_stack_info_json(self, file_path):
88
+ save_json(file_path, self.cache_stack, indent=1)
89
+
90
+ def write_construct_info_json(self, file_path):
91
+ save_json(file_path, self.cache_construct, indent=1)
92
+
93
+ def write_json(self):
94
+ self.write_data_json(self.dump_file_path)
95
+ self.write_stack_info_json(self.stack_file_path)
96
+ self.write_construct_info_json(self.construct_file_path)
@@ -1,178 +1,178 @@
1
- from abc import ABC, abstractmethod
2
- from msprobe.core.common.exceptions import ScopeException
3
- from msprobe.core.common.const import Const
4
-
5
-
6
def build_scope(scope_class, scope=None, api_list=None):
    """Create the scope filter described by ``scope`` and ``api_list``.

    Returns None when both are empty/None. Missing values default to empty
    lists. When ``scope_class`` is given it is instantiated as
    ``scope_class(scope, api_list)``. Otherwise the range-scope flavour is
    chosen by ``build_range_scope_according_to_scope_name``.
    """
    if not (scope or api_list):
        return None
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
16
-
17
-
18
def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick the range-scope flavour (API or Module) that fits ``scope``.

    Both candidates are constructed, and the one whose ``is_valid`` is True is
    returned. With an empty scope the flavours behave the same, so the API
    one is returned. Raises ``ScopeException(InvalidScope)`` when both or
    neither candidate is valid.
    """
    api_range_scope = APIRangeScope(scope, api_list)
    module_range_scope = ModuleRangeScope(scope, api_list)
    if not scope:
        # Without a scope argument either scope type works the same.
        return api_range_scope
    api_ok = api_range_scope.is_valid
    module_ok = module_range_scope.is_valid
    if api_ok and module_ok:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if api_ok:
        return api_range_scope
    if module_ok:
        return module_range_scope
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
31
-
32
-
33
class BaseScope(ABC):
    """Abstract base for dump-scope filters.

    Stores a normalised ``scope`` (list of names) and ``api_list`` (list of
    substrings). Subclasses implement ``check`` to decide whether a given
    name is selected.
    """
    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        # Looked up through self so subclasses can override rectify_args.
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate the arguments and normalise ``scope`` to a list.

        Returns:
            ``(scope, api_list)``. A str ``scope`` is wrapped in a one-element list.

        Raises:
            ScopeException: ``InvalidApiStr`` if ``api_list`` is not a list of
                str; ``InvalidScope`` if ``scope`` is neither a str nor a list of str.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for api in api_list:
            if not isinstance(api, str):
                raise ScopeException(ScopeException.InvalidApiStr,
                                     f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for s in scope:
            if not isinstance(s, str):
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        """Return whether ``name`` is selected by this scope."""

    def check_api_list(self, api_name):
        """True if ``api_list`` is empty or any entry is a substring of ``api_name``."""
        return not self.api_list or any(api_str in api_name for api_str in self.api_list)
74
-
75
-
76
class ListScope(BaseScope):
    """Scope selecting names that appear verbatim in ``scope``.

    ``scope`` and ``api_list`` may not both be configured.
    """

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject scope+api_list together, then apply the base validation."""
        if scope and api_list:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return BaseScope.rectify_args(scope, api_list)

    def check(self, module_name):
        """Selected if there is no scope, or ``module_name`` is listed, and api_list matches."""
        if self.scope and module_name not in self.scope:
            return False
        return self.check_api_list(module_name)
88
-
89
-
90
class RangeScope(BaseScope, ABC):
    """Base for scopes that select everything between a start and a stop name.

    After construction ``scope`` is either empty or a two-element
    ``[start, stop]`` list. ``in_scope`` tracks whether the range is currently
    open, and ``is_valid`` holds the result of ``check_scope_is_valid``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Normalise ``scope`` to ``[start, stop]``; a single name is both ends.

        A new list is built for the single-name case instead of appending to
        the caller's list, so the caller's configuration is not mutated.

        Raises:
            ScopeException: ``InvalidScope`` if more than two names are given
                (plus anything raised by the base validation).
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        """Return whether ``self.scope`` is acceptable for this range type."""

    def begin_module(self, module_name):
        """Hook for the start of a module; no-op by default."""

    def end_module(self, module_name):
        """Hook for the end of a module; no-op by default."""
119
-
120
-
121
class APIRangeScope(RangeScope):
    """Range scope whose endpoints are API names (not ``Module``-prefixed)."""

    def check_scope_is_valid(self):
        """Valid when the scope is empty or neither endpoint is a Module name."""
        if not self.scope:
            return True
        # The type prefix is the text before the first Const.SEP; the start
        # (index 0) is inspected before the stop (index 1).
        return all(self.scope[idx].split(Const.SEP)[0] != BaseScope.Module_Type_Module
                   for idx in (0, 1))

    def check(self, api_name):
        """Return whether ``api_name`` is selected, updating ``in_scope``.

        The range opens on ``scope[0]`` and closes after ``scope[1]``; both
        endpoints are themselves selected.
        """
        has_range = bool(self.scope)
        if has_range and api_name == self.scope[0]:
            self.in_scope = True

        selected = self.check_api_list(api_name) if (not has_range or self.in_scope) else False

        if has_range and api_name == self.scope[1]:
            self.in_scope = False
        return selected
145
-
146
-
147
class ModuleRangeScope(RangeScope):
    """
    Unlike an API, a module has inner sub-structures that also need dumping.
    Its start and end are therefore controlled precisely with pre_hook and
    full_backward_hook. When those hooks fire they call begin_module and
    end_module to open and close the range.
    """

    def check_scope_is_valid(self):
        """Valid when the scope is empty or both endpoints are Module names."""
        if not self.scope:
            return True
        start_type = self.scope[0].split(Const.SEP)[0]
        stop_type = self.scope[1].split(Const.SEP)[0]
        return start_type == BaseScope.Module_Type_Module and stop_type == BaseScope.Module_Type_Module

    def begin_module(self, module_name):
        """Open the range when the start module is entered."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when the stop module finishes."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Selected when there is no scope or the range is open, and api_list matches."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
1
+ from abc import ABC, abstractmethod
2
+ from msprobe.core.common.exceptions import ScopeException
3
+ from msprobe.core.common.const import Const
4
+
5
+
6
def build_scope(scope_class, scope=None, api_list=None):
    """Create the scope filter described by ``scope`` and ``api_list``.

    Returns None when both are empty/None. Missing values default to empty
    lists. When ``scope_class`` is given it is instantiated as
    ``scope_class(scope, api_list)``. Otherwise the range-scope flavour is
    chosen by ``build_range_scope_according_to_scope_name``.
    """
    if not (scope or api_list):
        return None
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
16
+
17
+
18
def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick the range-scope flavour (API or Module) that fits ``scope``.

    Both candidates are constructed, and the one whose ``is_valid`` is True is
    returned. With an empty scope the flavours behave the same, so the API
    one is returned. Raises ``ScopeException(InvalidScope)`` when both or
    neither candidate is valid.
    """
    api_range_scope = APIRangeScope(scope, api_list)
    module_range_scope = ModuleRangeScope(scope, api_list)
    if not scope:
        # Without a scope argument either scope type works the same.
        return api_range_scope
    api_ok = api_range_scope.is_valid
    module_ok = module_range_scope.is_valid
    if api_ok and module_ok:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if api_ok:
        return api_range_scope
    if module_ok:
        return module_range_scope
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
31
+
32
+
33
class BaseScope(ABC):
    """Abstract base for dump-scope filters.

    Stores a normalised ``scope`` (list of names) and ``api_list`` (list of
    substrings). Subclasses implement ``check`` to decide whether a given
    name is selected.
    """
    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        # Looked up through self so subclasses can override rectify_args.
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate the arguments and normalise ``scope`` to a list.

        Returns:
            ``(scope, api_list)``. A str ``scope`` is wrapped in a one-element list.

        Raises:
            ScopeException: ``InvalidApiStr`` if ``api_list`` is not a list of
                str; ``InvalidScope`` if ``scope`` is neither a str nor a list of str.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for api in api_list:
            if not isinstance(api, str):
                raise ScopeException(ScopeException.InvalidApiStr,
                                     f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for s in scope:
            if not isinstance(s, str):
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        """Return whether ``name`` is selected by this scope."""

    def check_api_list(self, api_name):
        """True if ``api_list`` is empty or any entry is a substring of ``api_name``."""
        return not self.api_list or any(api_str in api_name for api_str in self.api_list)
74
+
75
+
76
class ListScope(BaseScope):
    """Scope selecting names that appear verbatim in ``scope``.

    ``scope`` and ``api_list`` may not both be configured.
    """

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject scope+api_list together, then apply the base validation."""
        if scope and api_list:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return BaseScope.rectify_args(scope, api_list)

    def check(self, module_name):
        """Selected if there is no scope, or ``module_name`` is listed, and api_list matches."""
        if self.scope and module_name not in self.scope:
            return False
        return self.check_api_list(module_name)
88
+
89
+
90
class RangeScope(BaseScope, ABC):
    """Base for scopes that select everything between a start and a stop name.

    After construction ``scope`` is either empty or a two-element
    ``[start, stop]`` list. ``in_scope`` tracks whether the range is currently
    open, and ``is_valid`` holds the result of ``check_scope_is_valid``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Normalise ``scope`` to ``[start, stop]``; a single name is both ends.

        A new list is built for the single-name case instead of appending to
        the caller's list, so the caller's configuration is not mutated.

        Raises:
            ScopeException: ``InvalidScope`` if more than two names are given
                (plus anything raised by the base validation).
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        """Return whether ``self.scope`` is acceptable for this range type."""

    def begin_module(self, module_name):
        """Hook for the start of a module; no-op by default."""

    def end_module(self, module_name):
        """Hook for the end of a module; no-op by default."""
119
+
120
+
121
class APIRangeScope(RangeScope):
    """Range scope whose endpoints are API names (not ``Module``-prefixed)."""

    def check_scope_is_valid(self):
        """Valid when the scope is empty or neither endpoint is a Module name."""
        if not self.scope:
            return True
        # The type prefix is the text before the first Const.SEP; the start
        # (index 0) is inspected before the stop (index 1).
        return all(self.scope[idx].split(Const.SEP)[0] != BaseScope.Module_Type_Module
                   for idx in (0, 1))

    def check(self, api_name):
        """Return whether ``api_name`` is selected, updating ``in_scope``.

        The range opens on ``scope[0]`` and closes after ``scope[1]``; both
        endpoints are themselves selected.
        """
        has_range = bool(self.scope)
        if has_range and api_name == self.scope[0]:
            self.in_scope = True

        selected = self.check_api_list(api_name) if (not has_range or self.in_scope) else False

        if has_range and api_name == self.scope[1]:
            self.in_scope = False
        return selected
145
+
146
+
147
class ModuleRangeScope(RangeScope):
    """
    Unlike an API, a module has inner sub-structures that also need dumping.
    Its start and end are therefore controlled precisely with pre_hook and
    full_backward_hook. When those hooks fire they call begin_module and
    end_module to open and close the range.
    """

    def check_scope_is_valid(self):
        """Valid when the scope is empty or both endpoints are Module names."""
        if not self.scope:
            return True
        start_type = self.scope[0].split(Const.SEP)[0]
        stop_type = self.scope[1].split(Const.SEP)[0]
        return start_type == BaseScope.Module_Type_Module and stop_type == BaseScope.Module_Type_Module

    def begin_module(self, module_name):
        """Open the range when the start module is entered."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when the stop module finishes."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Selected when there is no scope or the range is open, and api_list matches."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
File without changes
@@ -0,0 +1,71 @@
1
+
2
class GradConst:
    """Constants shared by the gradient-monitoring tool."""

    # supported frameworks
    PYTORCH = "PyTorch"
    MindSpore = "MindSpore"
    FRAMEWORKS = {PYTORCH, MindSpore}

    # gradient file suffixes
    NPY_SUFFIX = "npy"
    PT_SUFFIX = "pt"
    GRAD_FILE_SUFFIX = {NPY_SUFFIX, PT_SUFFIX}

    # for callback
    CURRENT_STEP = "current_step"

    # config keys
    PARAM_LIST = "param_list"
    RANK = "rank"
    STEP = "step"
    BOUNDS = "bounds"
    OUTPUT_PATH = "output_path"

    # level const
    LEVEL = "level"
    LEVEL0 = "L0"
    LEVEL1 = "L1"
    LEVEL2 = "L2"
    SUPPORTED_LEVEL = {LEVEL0, LEVEL1, LEVEL2}

    # numpy coding -- NOTE(review): index meanings not visible here; confirm with the encoder
    STEP_IDX = 0
    SHAPE_DIM_IDX = 4
    MAX_SIZE = 10 * 1024 * 1024 * 1024  # presumably a byte limit (10 GiB)

    # direction suffix
    DIR_SUFFIX = "dir.npy"

    # file safety: permission modes, length limits and allowed-name patterns
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.]+$"
    DIR = "dir"
    FILE = "file"

    STEP_FINISH = "step_finish"

    SUMMARY = "summary"

    # csv header entry
    MD5 = "MD5"
    DISTRIBUTION = "distribution"
    SHAPE = "shape"
    MAX = "max"
    MIN = "min"
    NORM = "norm"
57
+
58
# Per-level configuration keyed by level name. "header" lists the CSV
# header entries for that level (in column order). "have_grad_direction"
# flags whether gradient-direction data is involved. NOTE(review): presumably
# the files named with GradConst.DIR_SUFFIX; confirm against the writer.
level_adp = {
    "L0": dict(
        header=[GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=False,
    ),
    "L1": dict(
        header=[GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
    "L2": dict(
        header=[GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
}