mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,116 +1,117 @@
1
- import os
2
- import csv
3
- import fcntl
4
- import json
5
- from pathlib import Path
6
-
7
- from msprobe.core.common.file_check import change_mode, FileOpen
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import Const, FileCheckConst
10
-
11
-
12
class DataWriter:
    """Cache dump data, stack info and construct info, and persist them to JSON/CSV files."""

    def __init__(self, init_json=None) -> None:
        self.dump_count = 0
        # Template for dump.json, used by write_data_json when the target file is missing or empty.
        self.init_json = init_json
        self.dump_file_path = None
        self.stack_file_path = None
        self.construct_file_path = None
        self.free_benchmark_file_path = None
        self.dump_tensor_data_dir = None
        # flush_data_when_buffer_is_full writes once this many data entries are cached.
        self.buffer_size = 1000
        self.cache_data = {Const.DATA: {}}
        self.cache_stack = {}
        self.cache_construct = {}

    @staticmethod
    def write_data_to_csv(result: list, result_header: tuple, file_path: str):
        """Append ``result`` as one CSV row, writing ``result_header`` first when the file is new.

        Does nothing when ``result`` is empty. A newly created file gets data-file permissions.
        """
        if not result:
            return
        is_new_file = not os.path.exists(file_path)
        with FileOpen(file_path, "w+" if is_new_file else "a+") as csv_file:
            spawn_writer = csv.writer(csv_file)
            if is_new_file:
                spawn_writer.writerow(result_header)
            spawn_writer.writerows([result,])
        if is_new_file:
            change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)

    @staticmethod
    def _recreate_empty_file(file_path):
        """Remove ``file_path`` if present, then create it empty with data-file permissions."""
        if os.path.exists(file_path):
            os.remove(file_path)
        Path(file_path).touch()
        change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)

    @staticmethod
    def _dump_json_locked(file_path, content):
        """Overwrite ``file_path`` with ``content`` as indented JSON while holding an exclusive flock."""
        with FileOpen(file_path, 'w+') as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                json.dump(content, f, indent=1)
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)

    def initialize_json_file(self, **kwargs):
        """Write the dump.json header built from ``kwargs`` and recreate empty stack/construct files.

        ``kwargs`` is extended with ``dump_data_dir`` and an empty data section before being written.
        """
        kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
        with FileOpen(self.dump_file_path, 'w') as f:
            json.dump(kwargs, f)
        change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
        self._recreate_empty_file(self.stack_file_path)
        self._recreate_empty_file(self.construct_file_path)

    def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
                          free_benchmark_file_path):
        """Record the output locations used by the write_* methods."""
        self.dump_file_path = dump_file_path
        self.stack_file_path = stack_file_path
        self.construct_file_path = construct_file_path
        self.dump_tensor_data_dir = dump_data_dir
        self.free_benchmark_file_path = free_benchmark_file_path

    def update_data(self, new_data):
        """Merge a ``{name: info}`` mapping into the cached dump data.

        If the first key is already cached, its cached value is updated in place; otherwise every
        entry of ``new_data`` is added. An empty mapping is ignored.
        """
        if not new_data:
            # next(iter(...)) on an empty mapping would raise StopIteration.
            return
        key = next(iter(new_data))  # callers are expected to pass exactly one key
        cached = self.cache_data[Const.DATA]
        if key in cached:
            cached[key].update(new_data[key])
        else:
            cached.update(new_data)

    def flush_data_when_buffer_is_full(self):
        """Write dump.json once the number of cached data entries reaches ``buffer_size``."""
        if len(self.cache_data[Const.DATA]) >= self.buffer_size:
            self.write_data_json(self.dump_file_path)

    def update_stack(self, new_data):
        """Merge ``new_data`` into the cached stack info."""
        self.cache_stack.update(new_data)

    def update_construct(self, new_data):
        """Merge ``new_data`` into the cached construct info."""
        self.cache_construct.update(new_data)

    def write_data_json(self, file_path):
        """Merge the cached data entries into the JSON at ``file_path``, then clear the cache.

        A non-empty existing file is loaded and extended. Otherwise the content starts from a shallow
        copy of ``init_json`` (which may be None), with ``data_path`` set to ``dump_tensor_data_dir``.
        """
        logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
        if Path(file_path).exists() and os.path.getsize(file_path) > 0:
            with FileOpen(file_path, "r+") as f:
                fcntl.flock(f, fcntl.LOCK_EX)
                try:
                    data_to_write = json.load(f)
                finally:
                    fcntl.flock(f, fcntl.LOCK_UN)
        else:
            # Copy the template (and its data section) so neither init_json nor its nested
            # dict accumulates dumped entries across flushes; tolerate init_json=None.
            data_to_write = dict(self.init_json or {})
            data_to_write['data_path'] = self.dump_tensor_data_dir
            data_to_write[Const.DATA] = dict(data_to_write.get(Const.DATA) or {})
        data_to_write.setdefault(Const.DATA, {}).update(self.cache_data[Const.DATA])
        self._dump_json_locked(file_path, data_to_write)

        self.cache_data[Const.DATA].clear()

    def write_stack_info_json(self, file_path):
        """Write the cached stack info as JSON to ``file_path``."""
        self._dump_json_locked(file_path, self.cache_stack)

    def write_construct_info_json(self, file_path):
        """Write the cached construct info as JSON to ``file_path``."""
        self._dump_json_locked(file_path, self.cache_construct)

    def write_json(self):
        """Write data, stack and construct caches to their configured files."""
        self.write_data_json(self.dump_file_path)
        self.write_stack_info_json(self.stack_file_path)
        self.write_construct_info_json(self.construct_file_path)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import csv
17
+ import os
18
+
19
+ from msprobe.core.common.const import Const, FileCheckConst
20
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
class DataWriter:
    """Cache dump data, stack info and construct info in memory and persist them as JSON files."""

    def __init__(self) -> None:
        self.dump_file_path = None
        self.stack_file_path = None
        self.construct_file_path = None
        self.free_benchmark_file_path = None
        self.dump_tensor_data_dir = None
        # flush_data_periodically writes to disk each time the number of cached data
        # entries reaches a multiple of this value.
        self.flush_size = 1000
        self.cache_data = {}
        self.cache_stack = {}
        self.cache_construct = {}
        # Entry count at the most recent periodic flush (0 = none yet).
        self._last_flush_count = 0

    @staticmethod
    def write_data_to_csv(result: list, result_header: tuple, file_path: str):
        """Append ``result`` as one CSV row, writing ``result_header`` first when the file is new.

        Does nothing when ``result`` is empty. A newly created file gets data-file permissions.
        """
        if not result:
            return
        is_exists = os.path.exists(file_path)
        append = "a+" if is_exists else "w+"
        with FileOpen(file_path, append) as csv_file:
            spawn_writer = csv.writer(csv_file)
            if not is_exists:
                spawn_writer.writerow(result_header)
            spawn_writer.writerows([result,])
        is_new_file = not is_exists
        if is_new_file:
            change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)

    def reset_cache(self):
        """Drop all cached data, stack and construct info."""
        self.cache_data = {}
        self.cache_stack = {}
        self.cache_construct = {}
        self._last_flush_count = 0

    def initialize_json_file(self, **kwargs):
        """Seed empty caches and write the initial JSON files.

        ``cache_data`` is initialised from ``kwargs`` (plus ``dump_data_dir`` and an empty data
        section) and written only while it is still empty; the stack and construct files are
        written only while their caches are empty.
        """
        if not self.cache_data:
            kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
            self.cache_data = kwargs
            save_json(self.dump_file_path, self.cache_data, indent=1)
        if not self.cache_stack:
            save_json(self.stack_file_path, self.cache_stack, indent=1)
        if not self.cache_construct:
            save_json(self.construct_file_path, self.cache_construct, indent=1)

    def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
                          free_benchmark_file_path):
        """Record the output locations used by the write_* methods."""
        self.dump_file_path = dump_file_path
        self.stack_file_path = stack_file_path
        self.construct_file_path = construct_file_path
        self.dump_tensor_data_dir = dump_data_dir
        self.free_benchmark_file_path = free_benchmark_file_path

    def flush_data_periodically(self):
        """Write all caches to disk each time the data section grows to a multiple of ``flush_size``."""
        dump_data = self.cache_data.get(Const.DATA)
        if not dump_data or not isinstance(dump_data, dict):
            return
        entry_count = len(dump_data)
        # The count grows only when a new key is added. Updating an existing key, or a call with
        # no new data, leaves it unchanged and must not trigger another full rewrite.
        if entry_count % self.flush_size == 0 and entry_count != self._last_flush_count:
            self.write_json()
            self._last_flush_count = entry_count

    def update_data(self, new_data):
        """Merge a single-key ``{name: info}`` mapping into the cached data section.

        Invalid input (not a dict, or not exactly one key) and a missing/non-dict data section
        are logged as warnings and ignored.
        """
        if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
            logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
            return
        dump_data = self.cache_data.get(Const.DATA)
        if not isinstance(dump_data, dict):
            logger.warning(f"The dump data({dump_data}) should be a dict.")
            return

        key = next(iter(new_data.keys()))
        if key in dump_data:
            dump_data.get(key).update(new_data.get(key))
        else:
            dump_data.update(new_data)

    def update_stack(self, new_data):
        """Merge ``new_data`` into the cached stack info."""
        self.cache_stack.update(new_data)

    def update_construct(self, new_data):
        """Merge ``new_data`` into the cached construct info."""
        self.cache_construct.update(new_data)

    def write_data_json(self, file_path):
        """Write the whole cached data (header and data section) as JSON to ``file_path``."""
        logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
        save_json(file_path, self.cache_data, indent=1)

    def write_stack_info_json(self, file_path):
        """Write the cached stack info as JSON to ``file_path``."""
        save_json(file_path, self.cache_stack, indent=1)

    def write_construct_info_json(self, file_path):
        """Write the cached construct info as JSON to ``file_path``."""
        save_json(file_path, self.cache_construct, indent=1)

    def write_json(self):
        """Write each non-empty cache to its configured file."""
        if self.cache_data:
            self.write_data_json(self.dump_file_path)
        if self.cache_stack:
            self.write_stack_info_json(self.stack_file_path)
        if self.cache_construct:
            self.write_construct_info_json(self.construct_file_path)
@@ -1,178 +1,194 @@
1
- from abc import ABC, abstractmethod
2
- from msprobe.core.common.exceptions import ScopeException
3
- from msprobe.core.common.const import Const
4
-
5
-
6
def build_scope(scope_class, scope=None, api_list=None):
    """Create a scope filter, or return None when neither ``scope`` nor ``api_list`` is set.

    ``None`` arguments are replaced by empty lists. If ``scope_class`` is given it is
    instantiated directly; otherwise the range-scope type is chosen from the scope names.
    """
    if not (scope or api_list):
        return None
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
16
-
17
-
18
def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick an API-range or module-range scope based on the endpoint names in ``scope``.

    Both candidates are constructed up front. With no ``scope`` the API-range scope is
    returned (either type behaves the same then). Exactly one valid candidate is returned;
    otherwise ScopeException(InvalidScope) is raised.
    """
    candidates = (APIRangeScope(scope, api_list), ModuleRangeScope(scope, api_list))
    if not scope:
        return candidates[0]
    valid = [candidate for candidate in candidates if candidate.is_valid]
    if len(valid) == 1:
        return valid[0]
    if len(valid) == 2:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
31
-
32
-
33
class BaseScope(ABC):
    """Common base for dump scopes: validates ``scope``/``api_list`` and filters names by API substrings."""
    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        # rectify_args is looked up through self, so a subclass override is used here.
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate argument types and return ``(scope, api_list)``, wrapping a string scope in a list.

        Raises ScopeException when ``api_list`` is not a list of strings, or ``scope`` is not a
        string or a list of strings.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for item in api_list:
            if isinstance(item, str):
                continue
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list中的元素须配置为字符串,实际类型为{type(item)}.")
        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for entry in scope:
            if isinstance(entry, str):
                continue
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope列表元素要求类型为字符串,实际类型为{type(entry)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        pass

    def check_api_list(self, api_name):
        """Return True if ``api_list`` is empty or any of its entries is a substring of ``api_name``."""
        return not self.api_list or any(api_str in api_name for api_str in self.api_list)
74
-
75
-
76
class ListScope(BaseScope):
    """Scope that accepts only names listed explicitly in ``scope`` (or everything when empty)."""

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject configuring both ``scope`` and ``api_list``, then apply the base validation."""
        if not (scope and api_list):
            return super(ListScope, ListScope).rectify_args(scope, api_list)
        raise ScopeException(ScopeException.ArgConflict,
                             f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")

    def check(self, module_name):
        """Return whether ``module_name`` is in the listed scope and passes the API filter."""
        listed = not self.scope or module_name in self.scope
        return self.check_api_list(module_name) if listed else False
88
-
89
-
90
class RangeScope(BaseScope, ABC):
    """Base for scopes given as an inclusive ``[start, end]`` pair of names.

    Subclasses decide which endpoints are acceptable (``check_scope_is_valid``) and toggle
    ``in_scope`` as the endpoints are encountered.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # True while execution is between the start and end markers.
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Normalise ``scope`` to a ``[start, end]`` list; a single name is used as both ends.

        Raises ScopeException when more than two names are given.
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # Build a new list instead of appending: the base validation returns a list
                # argument unchanged, so ``scope`` may be the caller's own config list.
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        """Return whether ``self.scope`` holds endpoints supported by this range type."""

    def begin_module(self, module_name):
        """Hook called when a module starts; no-op by default."""

    def end_module(self, module_name):
        """Hook called when a module ends; no-op by default."""
119
-
120
-
121
class APIRangeScope(RangeScope):
    """Range scope whose start and end markers are API names."""

    def check_scope_is_valid(self):
        """Valid when no scope is set, or when neither endpoint is a module name."""
        if not self.scope:
            return True
        endpoint_types = (marker.split(Const.SEP)[0] for marker in self.scope[:2])
        return BaseScope.Module_Type_Module not in endpoint_types

    def check(self, api_name):
        """Return whether ``api_name`` lies inside the range and passes the API filter.

        The start name opens the range before it is checked; the end name is still included
        and closes the range afterwards.
        """
        bounded = bool(self.scope)
        if bounded and api_name == self.scope[0]:
            self.in_scope = True

        inside = not bounded or self.in_scope
        result = self.check_api_list(api_name) if inside else False

        if bounded and api_name == self.scope[1]:
            self.in_scope = False
        return result
145
-
146
-
147
class ModuleRangeScope(RangeScope):
    """Range scope whose ``[start, end]`` markers are module names.

    Unlike APIs, a module has internal sub-structures that also need to be
    dumped. Its start and end are therefore controlled precisely via a
    pre_hook and a full_backward_hook, which call ``begin_module`` and
    ``end_module`` to drive the interval state.
    """

    def check_scope_is_valid(self):
        """Return True if the scope is empty or both endpoints are Module names."""
        if not self.scope:
            return True
        module_kind = BaseScope.Module_Type_Module
        start_kind = self.scope[0].split(Const.SEP)[0]
        stop_kind = self.scope[1].split(Const.SEP)[0]
        return start_kind == module_kind and stop_kind == module_kind

    def begin_module(self, module_name):
        """Open the interval when the start module is entered."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the interval when the end module is left."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Return False outside the interval; otherwise apply the api_list filter."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import ScopeException
20
+
21
+
22
def build_scope(scope_class, scope=None, api_list=None):
    """Create a scope filter from the ``scope`` / ``api_list`` configuration.

    Returns ``None`` when neither filter is configured. ``None`` values are
    replaced by empty lists. When ``scope_class`` is truthy it is
    instantiated directly as ``scope_class(scope, api_list)``. Otherwise the
    range-scope type is chosen from the names in ``scope``.
    """
    if not (scope or api_list):
        return None
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
32
+
33
+
34
def build_range_scope_according_to_scope_name(scope, api_list):
    """Return the range scope whose type matches the names in ``scope``.

    An API range scope and a module range scope are both constructed.
    Whichever reports ``is_valid`` is returned. With an empty ``scope`` the
    API range scope is returned, since the two behave the same.

    Raises:
        ScopeException: if both candidates, or neither, accept ``scope``.
    """
    api_range_scope = APIRangeScope(scope, api_list)
    module_range_scope = ModuleRangeScope(scope, api_list)
    if not scope:
        # No endpoints: either scope kind behaves identically.
        return api_range_scope
    api_ok = api_range_scope.is_valid
    module_ok = module_range_scope.is_valid
    if api_ok and module_ok:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if api_ok:
        return api_range_scope
    if module_ok:
        return module_range_scope
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
47
+
48
+
49
class BaseScope(ABC):
    """Abstract filter deciding which API / module names are selected.

    ``scope`` and ``api_list`` are validated and normalized by
    ``rectify_args`` (which subclasses may override) before being stored.
    """
    Module_Type_Module = "Module"
    Module_Type_API = "api"
    # Name prefixes (text before Const.SEP) that the range scopes treat as module-level.
    module_type = ["Module", "Cell"]

    def __init__(self, scope, api_list):
        normalized_scope, normalized_api_list = self.rectify_args(scope, api_list)
        self.scope = normalized_scope
        self.api_list = normalized_api_list

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate argument types and wrap a string ``scope`` in a list.

        Returns:
            ``(scope, api_list)``. A str ``scope`` becomes a one-element list;
            a list ``scope`` is returned as the same object.

        Raises:
            ScopeException: if ``api_list`` is not a list of str, or ``scope``
                is neither a str nor a list of str.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for item in api_list:
            if not isinstance(item, str):
                raise ScopeException(ScopeException.InvalidApiStr,
                                     f"api_list中的元素须配置为字符串,实际类型为{type(item)}.")
        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for item in scope:
            if not isinstance(item, str):
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope列表元素要求类型为字符串,实际类型为{type(item)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        """Return whether ``name`` is selected by this scope."""

    def check_api_list(self, api_name):
        """Return True if ``api_name`` passes the ``api_list`` filter.

        An empty ``api_list`` accepts everything; otherwise ``api_name`` must
        contain some ``api_list`` entry as a substring.
        """
        if not self.api_list:
            return True
        return any(fragment in api_name for fragment in self.api_list)
91
+
92
+
93
class ListScope(BaseScope):
    """Scope that selects names by exact membership in an explicit list.

    ``scope`` and ``api_list`` are mutually exclusive for this scope type.
    """

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject conflicting filters, then apply the base validation.

        Raises:
            ScopeException: if both ``scope`` and ``api_list`` are non-empty,
                or if the base validation fails.
        """
        both_configured = bool(scope) and bool(api_list)
        if both_configured:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return super(ListScope, ListScope).rectify_args(scope, api_list)

    def check(self, name):
        """Return True if ``name`` is listed (or no list is set) and passes api_list."""
        listed = not self.scope or name in self.scope
        return self.check_api_list(name) if listed else False
105
+
106
+
107
class RangeScope(BaseScope, ABC):
    """Base class for scopes defined by a ``[start, end]`` pair of names.

    Subclasses decide whether the endpoints are acceptable for them
    (``check_scope_is_valid``) and how entering/leaving the interval is
    tracked through ``in_scope``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # Toggled by subclasses when the start / end marker is encountered.
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Normalize ``scope`` into an empty list or a ``[start, end]`` list.

        A single-element scope ``[x]`` becomes ``[x, x]``. A new list is
        built rather than appending in place, so the caller's configuration
        list is never mutated. This matters because
        ``build_range_scope_according_to_scope_name`` passes the same list to
        two constructors.

        Raises:
            ScopeException: if ``scope`` has more than two elements, or if
                the base validation fails.
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # Copy instead of scope.append(...): the base class returns the
                # caller's own list object for list input.
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        """Return whether ``self.scope`` endpoints suit this range-scope type."""
        pass

    def begin_module(self, module_name):
        """Hook for module entry; a no-op unless overridden."""
        pass

    def end_module(self, module_name):
        """Hook for module exit; a no-op unless overridden."""
        pass
134
+
135
+
136
class APIRangeScope(RangeScope):
    """Range scope whose ``[start, end]`` markers are API names."""

    def check_scope_is_valid(self):
        """Return True if the scope is empty or neither endpoint has a module prefix."""
        if not self.scope:
            return True
        # Inspect start first, then end, mirroring the original evaluation order.
        for idx in (0, 1):
            prefix = self.scope[idx].split(Const.SEP)[0]
            if prefix in BaseScope.module_type:
                return False
        return True

    def check(self, name):
        """Return whether ``name`` lies within the current API interval.

        The start marker opens the interval before the check, so the start API
        itself is included. The end marker closes it after the check, so the
        end API is included too.
        """
        bounds = self.scope
        if bounds and name == bounds[0]:
            self.in_scope = True

        active = not bounds or self.in_scope
        result = self.check_api_list(name) if active else False

        if bounds and name == bounds[1]:
            self.in_scope = False
        return result
160
+
161
+
162
class ModuleRangeScope(RangeScope):
    """Range scope whose ``[start, end]`` markers are module names.

    Unlike APIs, a module has internal sub-structures that also need to be
    dumped. Its start and end are therefore controlled precisely via a
    pre_hook and a full_backward_hook, which call ``begin_module`` and
    ``end_module`` to drive the interval state.
    """

    def check_scope_is_valid(self):
        """Return True if the scope is empty or both endpoints have a module prefix."""
        if not self.scope:
            return True
        kinds = BaseScope.module_type
        start_kind = self.scope[0].split(Const.SEP)[0]
        stop_kind = self.scope[1].split(Const.SEP)[0]
        return start_kind in kinds and stop_kind in kinds

    def begin_module(self, module_name):
        """Open the interval when the start module is entered."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the interval when the end module is left."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, name):
        """Return False outside the interval; otherwise apply the api_list filter."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(name)