mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (323)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -182
  7. msprobe/__init__.py +1 -0
  8. msprobe/{config/config.json → config.json} +49 -27
  9. msprobe/core/__init__.py +0 -0
  10. msprobe/{pytorch → core}/advisor/advisor.py +124 -124
  11. msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
  12. msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
  13. msprobe/core/common/const.py +341 -241
  14. msprobe/core/common/exceptions.py +100 -88
  15. msprobe/core/common/{file_check.py → file_utils.py} +478 -265
  16. msprobe/core/common/log.py +76 -55
  17. msprobe/core/common/utils.py +385 -516
  18. msprobe/core/common_config.py +85 -58
  19. msprobe/core/compare/acc_compare.py +300 -0
  20. msprobe/core/compare/check.py +95 -0
  21. msprobe/core/compare/compare_cli.py +49 -0
  22. msprobe/core/compare/highlight.py +223 -0
  23. msprobe/core/compare/multiprocessing_compute.py +149 -0
  24. msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
  25. msprobe/core/compare/utils.py +430 -0
  26. msprobe/core/data_dump/data_collector.py +154 -140
  27. msprobe/core/data_dump/data_processor/base.py +314 -245
  28. msprobe/core/data_dump/data_processor/factory.py +59 -61
  29. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
  30. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
  31. msprobe/core/data_dump/json_writer.py +96 -116
  32. msprobe/core/data_dump/scope.py +178 -178
  33. msprobe/core/grad_probe/__init__.py +0 -0
  34. msprobe/core/grad_probe/constant.py +71 -0
  35. msprobe/core/grad_probe/grad_compare.py +171 -0
  36. msprobe/core/grad_probe/utils.py +64 -0
  37. msprobe/docs/01.installation.md +89 -0
  38. msprobe/docs/02.config_introduction.md +165 -0
  39. msprobe/docs/03.config_examples.md +247 -0
  40. msprobe/docs/04.acl_config_examples.md +76 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  52. msprobe/docs/17.grad_probe.md +207 -0
  53. msprobe/docs/FAQ_PyTorch.md +177 -0
  54. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  55. msprobe/docs/img/free_benchmark_framework.png +0 -0
  56. msprobe/docs/img/grad_probe_image-1.png +0 -0
  57. msprobe/docs/img/grad_probe_image-2.png +0 -0
  58. msprobe/docs/img/grad_probe_image-3.png +0 -0
  59. msprobe/docs/img/grad_probe_image-4.png +0 -0
  60. msprobe/docs/img/grad_probe_image.png +0 -0
  61. msprobe/mindspore/__init__.py +1 -1
  62. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  63. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
  64. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  65. msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
  66. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  67. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  68. msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
  69. msprobe/mindspore/api_accuracy_checker/main.py +9 -0
  70. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  71. msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
  72. msprobe/mindspore/cell_processor.py +34 -0
  73. msprobe/mindspore/common/const.py +106 -0
  74. msprobe/mindspore/common/log.py +38 -0
  75. msprobe/mindspore/common/utils.py +81 -0
  76. msprobe/mindspore/compare/distributed_compare.py +75 -0
  77. msprobe/mindspore/compare/ms_compare.py +219 -0
  78. msprobe/mindspore/compare/ms_graph_compare.py +348 -0
  79. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  80. msprobe/mindspore/debugger/debugger_config.py +66 -51
  81. msprobe/mindspore/debugger/precision_debugger.py +126 -32
  82. msprobe/mindspore/dump/dump_tool_factory.py +35 -38
  83. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
  86. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  87. msprobe/mindspore/dump/jit_dump.py +72 -0
  88. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  89. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
  90. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  91. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  92. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  93. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  95. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  97. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  98. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
  99. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  100. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
  110. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  111. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
  112. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  113. msprobe/mindspore/grad_probe/__init__.py +0 -0
  114. msprobe/mindspore/grad_probe/global_context.py +90 -0
  115. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  116. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  117. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  118. msprobe/mindspore/grad_probe/hook.py +94 -0
  119. msprobe/mindspore/grad_probe/utils.py +30 -0
  120. msprobe/mindspore/ms_config.py +128 -78
  121. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  122. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
  123. msprobe/mindspore/runtime.py +4 -0
  124. msprobe/mindspore/service.py +378 -0
  125. msprobe/mindspore/task_handler_factory.py +24 -21
  126. msprobe/msprobe.py +105 -67
  127. msprobe/pytorch/__init__.py +4 -4
  128. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
  129. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
  130. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
  131. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
  132. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  133. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  134. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
  135. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  136. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
  137. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
  138. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
  139. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
  140. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
  141. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
  142. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
  143. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  148. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
  149. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  150. msprobe/pytorch/bench_functions/__init__.py +15 -0
  151. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  152. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  153. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  154. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  155. msprobe/pytorch/bench_functions/linear.py +12 -0
  156. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  157. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
  158. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  159. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  160. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  161. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  162. msprobe/pytorch/common/__init__.py +2 -2
  163. msprobe/pytorch/common/compare_script.template +14 -14
  164. msprobe/pytorch/common/log.py +20 -31
  165. msprobe/pytorch/common/parse_json.py +39 -37
  166. msprobe/pytorch/common/utils.py +305 -224
  167. msprobe/pytorch/compare/distributed_compare.py +66 -111
  168. msprobe/pytorch/compare/mapping.yaml +607 -607
  169. msprobe/pytorch/compare/match.py +34 -36
  170. msprobe/pytorch/compare/pt_compare.py +50 -0
  171. msprobe/pytorch/debugger/debugger_config.py +95 -86
  172. msprobe/pytorch/debugger/precision_debugger.py +125 -95
  173. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  174. msprobe/pytorch/free_benchmark/common/constant.py +70 -67
  175. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  176. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  177. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  178. msprobe/pytorch/free_benchmark/common/utils.py +102 -98
  179. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
  180. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  181. msprobe/pytorch/free_benchmark/main.py +105 -102
  182. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  183. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  188. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  189. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  190. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  191. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
  192. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  193. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  194. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
  195. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  196. msprobe/pytorch/function_factory.py +76 -0
  197. msprobe/pytorch/functional/dump_module.py +39 -39
  198. msprobe/pytorch/grad_probe/__init__.py +0 -0
  199. msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
  200. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  201. msprobe/pytorch/hook_module/api_registry.py +161 -161
  202. msprobe/pytorch/hook_module/hook_module.py +120 -109
  203. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
  204. msprobe/pytorch/hook_module/utils.py +30 -29
  205. msprobe/pytorch/hook_module/wrap_aten.py +110 -100
  206. msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
  207. msprobe/pytorch/hook_module/wrap_functional.py +105 -108
  208. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
  209. msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
  210. msprobe/pytorch/hook_module/wrap_torch.py +86 -88
  211. msprobe/pytorch/hook_module/wrap_vf.py +62 -64
  212. msprobe/pytorch/module_processer.py +138 -98
  213. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  214. msprobe/pytorch/online_dispatch/compare.py +236 -236
  215. msprobe/pytorch/online_dispatch/dispatch.py +271 -273
  216. msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
  217. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  218. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  219. msprobe/pytorch/online_dispatch/utils.py +130 -187
  220. msprobe/pytorch/parse.py +4 -4
  221. msprobe/pytorch/parse_tool/cli.py +32 -32
  222. msprobe/pytorch/parse_tool/lib/compare.py +260 -259
  223. msprobe/pytorch/parse_tool/lib/config.py +52 -51
  224. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  225. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  226. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  227. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  228. msprobe/pytorch/parse_tool/lib/utils.py +316 -367
  229. msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
  230. msprobe/pytorch/pt_config.py +188 -93
  231. msprobe/pytorch/service.py +246 -167
  232. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  233. msprobe/config/README.md +0 -397
  234. msprobe/mindspore/doc/dump.md +0 -65
  235. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  236. msprobe/pytorch/compare/acc_compare.py +0 -1024
  237. msprobe/pytorch/compare/highlight.py +0 -100
  238. msprobe/pytorch/doc/FAQ.md +0 -193
  239. msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
  240. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  241. msprobe/pytorch/doc/dump.md +0 -207
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/在线精度比对.md +0 -90
  247. msprobe/test/core_ut/common/test_utils.py +0 -345
  248. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  249. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  250. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  251. msprobe/test/core_ut/test_common_config.py +0 -152
  252. msprobe/test/core_ut/test_file_check.py +0 -218
  253. msprobe/test/core_ut/test_log.py +0 -109
  254. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  255. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  256. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  257. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  258. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  259. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  260. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  261. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  262. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  263. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  264. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  265. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  266. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  267. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  268. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  269. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  270. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  271. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  272. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  273. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  274. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  275. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  276. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  277. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  278. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  279. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  280. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  281. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  282. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  283. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  284. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  285. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  286. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  287. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  288. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  289. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  290. msprobe/test/pytorch_ut/test_service.py +0 -59
  291. msprobe/test/resources/advisor.txt +0 -3
  292. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  293. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  294. msprobe/test/resources/config.yaml +0 -3
  295. msprobe/test/resources/npu_test.pkl +0 -8
  296. msprobe/test/run_test.sh +0 -30
  297. msprobe/test/run_ut.py +0 -58
  298. msprobe/test/test_module_processer.py +0 -64
  299. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  300. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  301. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  302. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  303. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  304. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  305. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  306. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  307. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  308. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  309. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  310. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  311. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  312. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  313. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  314. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  315. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  316. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  317. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  318. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  319. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  320. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  321. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  322. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  323. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,345 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import os
18
- import uuid
19
-
20
- from unittest import TestCase
21
- from unittest.mock import patch, MagicMock, mock_open
22
-
23
- from msprobe.core.common.log import logger
24
- from msprobe.core.common.const import Const
25
- from msprobe.core.common.utils import (CompareException,
26
- check_seed_all,
27
- check_inplace_op,
28
- make_dump_path_if_not_exists,
29
- check_mode_valid,
30
- check_switch_valid,
31
- check_dump_mode_valid,
32
- check_summary_mode_valid,
33
- check_summary_only_valid,
34
- check_file_or_directory_path,
35
- check_compare_param,
36
- check_configuration_param,
37
- is_starts_with,
38
- _check_json,
39
- check_json_file,
40
- check_file_size,
41
- check_regex_prefix_format_valid,
42
- get_dump_data_path,
43
- task_dumppath_get)
44
- from msprobe.core.common.file_check import FileCheckConst
45
-
46
-
47
- class TestUtils(TestCase):
48
- @patch.object(logger, "error")
49
- def test_check_seed_all(self, mock_error):
50
- self.assertIsNone(check_seed_all(1234, True))
51
- self.assertIsNone(check_seed_all(0, True))
52
- self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True))
53
-
54
- with self.assertRaises(CompareException) as context:
55
- check_seed_all(-1, True)
56
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
57
- mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
58
-
59
- with self.assertRaises(CompareException) as context:
60
- check_seed_all(Const.MAX_SEED_VALUE + 1, True)
61
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
62
- mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
63
-
64
- with self.assertRaises(CompareException) as context:
65
- check_seed_all("1234", True)
66
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
67
- mock_error.assert_called_with("Seed must be integer.")
68
-
69
- with self.assertRaises(CompareException) as context:
70
- check_seed_all(1234, 1)
71
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
72
- mock_error.assert_called_with("seed_all mode must be bool.")
73
-
74
- def test_check_inplace_op(self):
75
- test_prefix_1 = "Distributed.broadcast.0.forward.input.0"
76
- self.assertTrue(check_inplace_op(test_prefix_1))
77
- test_prefix_2 = "Distributed_broadcast_0_forward_input_0"
78
- self.assertFalse(check_inplace_op(test_prefix_2))
79
- test_prefix_3 = "Torch.sum.0.backward.output.0"
80
- self.assertFalse(check_inplace_op(test_prefix_3))
81
-
82
- @patch.object(logger, "error")
83
- def test_make_dump_path_if_not_exists(self, mock_error):
84
- file_path = os.path.realpath(__file__)
85
- dirname = os.path.dirname(file_path) + str(uuid.uuid4())
86
-
87
- def test_mkdir(self, **kwargs):
88
- raise OSError
89
-
90
- if not os.path.exists(dirname):
91
- with patch("msprobe.core.common.utils.Path.mkdir", new=test_mkdir):
92
- with self.assertRaises(CompareException) as context:
93
- make_dump_path_if_not_exists(dirname)
94
- self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
95
-
96
- make_dump_path_if_not_exists(file_path)
97
- mock_error.assert_called_with(f"{file_path} already exists and is not a directory.")
98
-
99
- def test_check_mode_valid(self):
100
- with self.assertRaises(ValueError) as context:
101
- check_mode_valid("all", scope="scope")
102
- self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.")
103
-
104
- with self.assertRaises(ValueError) as context:
105
- check_mode_valid("all", api_list="api_list")
106
- self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.")
107
-
108
- mode = "all_list"
109
- with self.assertRaises(CompareException) as context:
110
- check_mode_valid(mode)
111
- self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE)
112
- self.assertEqual(str(context.exception),
113
- f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}")
114
-
115
- mode = "list"
116
- with self.assertRaises(ValueError) as context:
117
- check_mode_valid(mode)
118
- self.assertEqual(str(context.exception),
119
- "set_dump_switch, scope param set invalid, it's should not be an empty list.")
120
-
121
- @patch.object(logger, "error")
122
- def test_check_switch_valid(self, mock_error):
123
- with self.assertRaises(CompareException) as context:
124
- check_switch_valid("Close")
125
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
126
- mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.")
127
-
128
- @patch.object(logger, "warning")
129
- def test_check_dump_mode_valid(self, mock_warning):
130
- dump_mode = check_dump_mode_valid("all")
131
- mock_warning.assert_called_with("Please set dump_mode as a list.")
132
- self.assertEqual(dump_mode, ["forward", "backward", "input", "output"])
133
-
134
- with self.assertRaises(ValueError) as context:
135
- check_dump_mode_valid("all_forward")
136
- self.assertEqual(str(context.exception),
137
- "Please set dump_mode as a list containing one or more of the following: " +
138
- "'all', 'forward', 'backward', 'input', 'output'.")
139
-
140
- def test_check_summary_mode_valid(self):
141
- with self.assertRaises(CompareException) as context:
142
- check_summary_mode_valid("MD5")
143
- self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE)
144
- self.assertEqual(str(context.exception), "The summary_mode is not valid")
145
-
146
- @patch.object(logger, "error")
147
- def test_check_summary_only_valid(self, mock_error):
148
- summary_only = check_summary_only_valid(True)
149
- self.assertTrue(summary_only)
150
-
151
- with self.assertRaises(CompareException) as context:
152
- check_summary_only_valid("True")
153
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
154
- mock_error.assert_called_with("Params summary_only only support True or False.")
155
-
156
- def test_check_file_or_directory_path(self):
157
- class TestFileChecker:
158
- file_path = ""
159
- path_type = ""
160
- ability = ""
161
- checked = False
162
-
163
- def __init__(self, file_path, path_type, ability=None):
164
- TestFileChecker.file_path = file_path
165
- TestFileChecker.path_type = path_type
166
- TestFileChecker.ability = ability
167
-
168
- def common_check(self):
169
- TestFileChecker.checked = True
170
-
171
- file_path = os.path.realpath(__file__)
172
- dirname = os.path.dirname(file_path)
173
-
174
- with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
175
- check_file_or_directory_path(file_path, isdir=False)
176
- self.assertTrue(TestFileChecker.checked)
177
- self.assertEqual(TestFileChecker.file_path, file_path)
178
- self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE)
179
- self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE)
180
-
181
- TestFileChecker.checked = False
182
- with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
183
- check_file_or_directory_path(dirname, isdir=True)
184
- self.assertTrue(TestFileChecker.checked)
185
- self.assertEqual(TestFileChecker.file_path, dirname)
186
- self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR)
187
- self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE)
188
-
189
- @patch.object(logger, "error")
190
- def test_check_compare_param(self, mock_error):
191
- params = {
192
- "npu_json_path": "npu_json_path",
193
- "bench_json_path": "bench_json_path",
194
- "stack_json_path": "stack_json_path",
195
- "npu_dump_data_dir": "npu_dump_data_dir",
196
- "bench_dump_data_dir": "bench_dump_data_dir"
197
- }
198
-
199
- call_args = [
200
- ("npu_json_path", False),
201
- ("bench_json_path", False),
202
- ("stack_json_path", False),
203
- ("npu_dump_data_dir", True),
204
- ("bench_dump_data_dir", True),
205
- ("output_path", True),
206
- ("npu_json_path", False),
207
- ("bench_json_path", False),
208
- ("stack_json_path", False),
209
- ("output_path", True)
210
- ]
211
-
212
- with self.assertRaises(CompareException) as context:
213
- check_compare_param("npu_json_path", "output_path")
214
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
215
- mock_error.assert_called_with("Invalid input parameters")
216
-
217
- mock_check_file_or_directory_path = MagicMock()
218
- mock_check_json_file = MagicMock()
219
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
220
- patch("msprobe.core.common.utils.check_json_file", new=mock_check_json_file), \
221
- patch("msprobe.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
222
- check_compare_param(params, "output_path")
223
- check_compare_param(params, "output_path", summary_compare=False, md5_compare=True)
224
- for i in range(len(call_args)):
225
- self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], call_args[i])
226
- self.assertEqual(len(mock_check_json_file.call_args[0]), 4)
227
- self.assertEqual(mock_check_json_file.call_args[0][0], params)
228
-
229
- @patch.object(logger, "error")
230
- def test_check_configuration_param(self, mock_error):
231
- with self.assertRaises(CompareException) as context:
232
- check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False)
233
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
234
- mock_error.assert_called_with("Invalid input parameters which should be only bool type.")
235
-
236
- def test_is_starts_with(self):
237
- string = "input_slot0"
238
- self.assertFalse(is_starts_with(string, []))
239
- self.assertFalse(is_starts_with("", ["input"]))
240
- self.assertFalse(is_starts_with(string, ["output"]))
241
- self.assertTrue(is_starts_with(string, ["input", "output"]))
242
-
243
- @patch.object(logger, "error")
244
- def test__check_json(self, mock_error):
245
- class TestOpen:
246
- def __init__(self, string):
247
- self.string = string
248
-
249
- def readline(self):
250
- return self.string
251
-
252
- def seek(self, begin, end):
253
- self.string = str(begin) + "_" + str(end)
254
-
255
- with self.assertRaises(CompareException) as context:
256
- _check_json(TestOpen(""), "test.json")
257
- self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE)
258
- mock_error.assert_called_with("dump file test.json have empty line!")
259
-
260
- handler = TestOpen("jons file\n")
261
- _check_json(handler, "test.json")
262
- self.assertEqual(handler.string, "0_0")
263
-
264
- @patch("msprobe.core.common.utils._check_json")
265
- def test_check_json_file(self, _mock_check_json):
266
- input_param = {
267
- "npu_json_path": "npu_json_path",
268
- "bench_json_path": "bench_json_path",
269
- "stack_json_path": "stack_json_path"
270
- }
271
- check_json_file(input_param, "npu_json", "bench_json", "stack_json")
272
- self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path"))
273
- self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path"))
274
- self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path"))
275
-
276
- @patch.object(logger, "error")
277
- def test_check_file_size(self, mock_error):
278
- with patch("msprobe.core.common.utils.os.path.getsize", return_value=120):
279
- with self.assertRaises(CompareException) as context:
280
- check_file_size("input_file", 100)
281
- self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR)
282
- mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.")
283
-
284
- def test_check_regex_prefix_format_valid(self):
285
- prefix = "A" * 21
286
- with self.assertRaises(ValueError) as context:
287
- check_regex_prefix_format_valid(prefix)
288
- self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, "
289
- f"while current length is {len(prefix)}")
290
-
291
- prefix = "(prefix)"
292
- with self.assertRaises(ValueError) as context:
293
- check_regex_prefix_format_valid(prefix)
294
- self.assertEqual(str(context.exception), f"prefix contains invalid characters, "
295
- f"prefix pattern {Const.REGEX_PREFIX_PATTERN}")
296
-
297
- @patch("msprobe.core.common.utils.check_file_or_directory_path")
298
- def test_get_dump_data_path(self, mock_check_file_or_directory_path):
299
- file_path = os.path.realpath(__file__)
300
- dirname = os.path.dirname(file_path)
301
-
302
- dump_data_path, file_is_exist = get_dump_data_path(dirname)
303
- self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True))
304
- self.assertEqual(dump_data_path, dirname)
305
- self.assertTrue(file_is_exist)
306
-
307
- @patch.object(logger, "error")
308
- def test_task_dumppath_get(self, mock_error):
309
- input_param = {
310
- "npu_json_path": None,
311
- "bench_json_path": "bench_json_path"
312
- }
313
- npu_json = {
314
- "task": Const.TENSOR,
315
- "dump_data_dir": "dump_data_dir",
316
- "data": "data"
317
- }
318
-
319
- with self.assertRaises(CompareException) as context:
320
- task_dumppath_get(input_param)
321
- self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
322
- mock_error.assert_called_with("Please check the json path is valid.")
323
-
324
- input_param["npu_json_path"] = "npu_json_path"
325
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
326
- patch("msprobe.core.common.utils.json.load", return_value=npu_json):
327
- summary_compare, md5_compare = task_dumppath_get(input_param)
328
- self.assertFalse(summary_compare)
329
- self.assertFalse(md5_compare)
330
-
331
- npu_json["task"] = Const.STATISTICS
332
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
333
- patch("msprobe.core.common.utils.json.load", return_value=npu_json), \
334
- patch("msprobe.core.common.utils.md5_find", return_value=True):
335
- summary_compare, md5_compare = task_dumppath_get(input_param)
336
- self.assertFalse(summary_compare)
337
- self.assertTrue(md5_compare)
338
-
339
- npu_json["task"] = Const.OVERFLOW_CHECK
340
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
341
- patch("msprobe.core.common.utils.json.load", return_value=npu_json):
342
- with self.assertRaises(CompareException) as context:
343
- task_dumppath_get(input_param)
344
- self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR)
345
- mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.")
@@ -1,47 +0,0 @@
1
- import unittest
2
- from unittest.mock import patch, mock_open, MagicMock
3
-
4
- from msprobe.core.common.utils import Const
5
- from msprobe.core.data_dump.data_collector import DataCollector
6
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
7
- from msprobe.pytorch.pt_config import parse_json_config
8
-
9
-
10
- class TestDataCollector(unittest.TestCase):
11
- def setUp(self):
12
- mock_json_data = {
13
- "dump_path": "./ut_dump",
14
- }
15
- with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
16
- patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
17
- common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
18
- config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
19
- self.data_collector = DataCollector(config)
20
-
21
- def test_update_data(self):
22
- self.data_collector.config.task = Const.OVERFLOW_CHECK
23
- self.data_collector.data_processor.has_overflow = True
24
- with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
25
- result1 = self.data_collector.update_data("test message", "test1:")
26
- self.assertEqual(result1, "test1:Overflow detected.")
27
-
28
- self.data_collector.data_processor.has_overflow = False
29
- result2 = self.data_collector.update_data("test message", "test2:")
30
- self.assertEqual(result2, "test2:No Overflow, OK.")
31
-
32
- self.data_collector.config.task = Const.STATISTICS
33
- self.data_collector.data_processor.has_overflow = True
34
- with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
35
- result3 = self.data_collector.update_data("test message", "test3")
36
- self.assertEqual(result3, "test3")
37
-
38
- def test_pre_forward_data_collect(self):
39
- self.data_collector.check_scope_and_pid = MagicMock(return_value=False)
40
- self.data_collector.is_inplace = MagicMock(return_value=False)
41
- self.data_collector.data_processor.analyze_pre_forward = MagicMock()
42
- name = "TestModule.forward"
43
- pid = 123
44
-
45
- self.data_collector.pre_forward_data_collect(name, None, pid, None)
46
- self.data_collector.check_scope_and_pid.assert_called_once_with(
47
- self.data_collector.scope, "TestModule.backward", 123)
@@ -1,183 +0,0 @@
1
- import unittest
2
- from msprobe.core.data_dump.json_writer import DataWriter
3
-
4
- import os
5
- import csv
6
- from msprobe.core.common.file_check import FileOpen
7
- from msprobe.core.common import utils
8
- from pathlib import Path
9
- import json
10
-
11
- class TestDataWriter(unittest.TestCase):
12
- def test_write_data_to_csv(self):
13
- cur_path = os.path.dirname(os.path.realpath(__file__))
14
- file_path = os.path.join(cur_path, "test.csv")
15
-
16
- if os.path.exists(file_path):
17
- utils.remove_path(file_path)
18
-
19
- data = {"A":"1", "B":"2", "C":"3"}
20
- result = data.values()
21
- header = data.keys()
22
- DataWriter.write_data_to_csv(result, header, file_path)
23
- with FileOpen(file_path, "r") as f:
24
- reader = csv.DictReader(f)
25
- column_first = [row for row in reader][0]
26
- self.assertEqual(data, column_first)
27
-
28
-
29
-
30
-
31
- data = {"A":"4", "B":"5", "C":"6"}
32
- result = data.values()
33
- header = data.keys()
34
- DataWriter.write_data_to_csv(result, header, file_path)
35
- with FileOpen(file_path, "r") as f:
36
- reader = csv.DictReader(f)
37
- column_last = [row for row in reader][-1]
38
- self.assertEqual(data, column_last)
39
-
40
- utils.remove_path(file_path)
41
-
42
- def test_initialize_json_file(self):
43
- cur_path = os.path.dirname(os.path.realpath(__file__))
44
- dump_tensor_data_dir = os.path.join(cur_path, "dump_tensor_data.json")
45
- dump_file_path = os.path.join(cur_path, "dump_file.json")
46
- stack_file_path = os.path.join(cur_path, "stack_file.json")
47
- construct_file_path = os.path.join(cur_path, "construct_file.json")
48
- if not os.path.exists(stack_file_path):
49
- Path(stack_file_path).touch()
50
- if not os.path.exists(construct_file_path):
51
- Path(construct_file_path).touch()
52
-
53
- test = DataWriter()
54
- test.stack_file_path = stack_file_path
55
- test.dump_file_path = dump_file_path
56
- test.dump_tensor_data_dir = dump_tensor_data_dir
57
- test.construct_file_path = construct_file_path
58
-
59
- test.initialize_json_file()
60
-
61
- with open(dump_file_path) as f:
62
- load_data = json.load(f)
63
- result = {"dump_data_dir": dump_tensor_data_dir, "data": {}}
64
- self.assertEqual(result, load_data)
65
- is_exist_1 = os.path.exists(test.stack_file_path)
66
- self.assertTrue(is_exist_1)
67
- os.access(test.stack_file_path, os.R_OK)
68
- os.access(test.stack_file_path, os.W_OK)
69
- is_exist_2 = os.path.exists(test.construct_file_path)
70
- self.assertTrue(is_exist_2)
71
- os.access(test.construct_file_path, os.R_OK)
72
- os.access(test.construct_file_path, os.W_OK)
73
-
74
- os.remove(construct_file_path)
75
- os.remove(stack_file_path)
76
- os.remove(dump_file_path)
77
-
78
- def test_update_dump_paths(self):
79
- test = DataWriter()
80
- self.assertTrue(test.dump_file_path == None)
81
-
82
- cur_path = os.path.dirname(os.path.realpath(__file__))
83
- test_path = os.path.join(cur_path, "test1.json")
84
-
85
- test.update_dump_paths(test_path, test_path, test_path, test_path, test_path)
86
- self.assertTrue(test.dump_file_path == test_path)
87
- self.assertTrue(test.stack_file_path == test_path)
88
- self.assertTrue(test.construct_file_path == test_path)
89
- self.assertTrue(test.dump_tensor_data_dir == test_path)
90
- self.assertTrue(test.free_benchmark_file_path == test_path)
91
-
92
- def test_update_data(self):
93
- data = {"A":"1", "B":"2", "C":{"D":"2"}}
94
- test = DataWriter()
95
- test.cache_data["data"]["test_1"] = True
96
- test.cache_data["data"]["test_2"] = False
97
-
98
- test.update_data(data)
99
- self.assertEqual(test.cache_data["data"]["A"], "1")
100
-
101
- new_data = {"C":{"F":3}}
102
- test.update_data(new_data)
103
- self.assertEqual(test.cache_data["data"]["C"]["F"], 3)
104
-
105
-
106
- def test_flush_data_when_buffer_is_full_and_test_write_data_json(self):
107
- data = {"A":"1", "B":"2", "data":{}}
108
- test = DataWriter()
109
- test.buffer_size = 1
110
- test.cache_data["data"] = {"A":"1", "B":"2", "C":"3"}
111
-
112
- self.assertTrue(len(test.cache_data["data"]) >= test.buffer_size)
113
- cur_path = os.path.dirname(os.path.realpath(__file__))
114
- dump_tensor_data_dir = os.path.join(cur_path, "dump_tensor_data.json")
115
- dump_file_path = os.path.join(cur_path, "dump_file.json")
116
- stack_file_path = os.path.join(cur_path, "stack_file.json")
117
- construct_file_path = os.path.join(cur_path, "construct_file.json")
118
-
119
- test.dump_file_path = dump_file_path
120
- test.dump_tensor_data_dir = dump_tensor_data_dir
121
-
122
- with open(dump_file_path, "w") as f:
123
- dump_data = json.dumps(data)
124
- f.write(dump_data)
125
-
126
- test.flush_data_when_buffer_is_full()
127
-
128
- with open(dump_file_path, "r") as f:
129
- new_data = json.load(f)
130
-
131
- data.update({"data": {"A":"1", "B":"2", "C":"3"}})
132
- self.assertEqual(new_data, data)
133
-
134
- self.assertTrue(test.cache_data["data"] == {})
135
- os.remove(dump_file_path)
136
-
137
-
138
- def test_update_stack(self):
139
- data = {"A":"1", "B":"2", "data":{}}
140
- test = DataWriter()
141
- test.update_stack(data)
142
- self.assertEqual(test.cache_stack, data)
143
-
144
- def test_update_construct(self):
145
- data = {"A":"1", "B":"2", "data":{}}
146
- test = DataWriter()
147
- test.update_construct(data)
148
- self.assertEqual(test.cache_construct, data)
149
-
150
- def test_write_stack_info_json(self):
151
- test = DataWriter()
152
- data = {"A":"1", "B":"2", "data":{}}
153
- test.cache_stack = data
154
-
155
- cur_path = os.path.dirname(os.path.realpath(__file__))
156
- file_path = os.path.join(cur_path, "dump.json")
157
-
158
- test.write_stack_info_json(file_path)
159
-
160
- with open(file_path, "r") as f:
161
- load_result = json.load(f)
162
- try:
163
- self.assertEqual(load_result, data)
164
- finally:
165
- os.remove(file_path)
166
-
167
-
168
- def test_write_construct_info_json(self):
169
- test = DataWriter()
170
- data = {"A":"1", "B":"2", "data":{}}
171
- test.cache_construct = data
172
-
173
- cur_path = os.path.dirname(os.path.realpath(__file__))
174
- file_path = os.path.join(cur_path, "dump.json")
175
-
176
- test.write_construct_info_json(file_path)
177
-
178
- with open(file_path, "r") as f:
179
- load_result = json.load(f)
180
- try:
181
- self.assertEqual(load_result, data)
182
- finally:
183
- os.remove(file_path)