mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,116 +1,96 @@
1
- import os
2
- import csv
3
- import fcntl
4
- import json
5
- from pathlib import Path
6
-
7
- from msprobe.core.common.file_check import change_mode, FileOpen
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import Const, FileCheckConst
10
-
11
-
12
- class DataWriter:
13
-
14
- def __init__(self, init_json=None) -> None:
15
- self.dump_count = 0
16
- self.init_json = init_json
17
- self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
18
- self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
19
- self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
20
- self.free_benchmark_file_path = None
21
- self.dump_tensor_data_dir = None
22
- self.buffer_size = 1000
23
- self.cache_data = {Const.DATA: {}}
24
- self.cache_stack = {}
25
- self.cache_construct = {}
26
-
27
- @staticmethod
28
- def write_data_to_csv(result: list, result_header: tuple, file_path: str):
29
- if not result:
30
- return
31
- is_exists = os.path.exists(file_path)
32
- append = "a+" if is_exists else "w+"
33
- with FileOpen(file_path, append) as csv_file:
34
- spawn_writer = csv.writer(csv_file)
35
- if not is_exists:
36
- spawn_writer.writerow(result_header)
37
- spawn_writer.writerows([result,])
38
- is_new_file = not is_exists
39
- if is_new_file:
40
- change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
41
-
42
- def initialize_json_file(self, **kwargs):
43
- kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
44
- with FileOpen(self.dump_file_path, 'w') as f:
45
- json.dump(kwargs, f)
46
- change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
47
-
48
- if os.path.exists(self.stack_file_path):
49
- os.remove(self.stack_file_path)
50
- Path(self.stack_file_path).touch()
51
- change_mode(self.stack_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
52
-
53
- if os.path.exists(self.construct_file_path):
54
- os.remove(self.construct_file_path)
55
- Path(self.construct_file_path).touch()
56
- change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
57
-
58
- def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
59
- free_benchmark_file_path):
60
- self.dump_file_path = dump_file_path
61
- self.stack_file_path = stack_file_path
62
- self.construct_file_path = construct_file_path
63
- self.dump_tensor_data_dir = dump_data_dir
64
- self.free_benchmark_file_path = free_benchmark_file_path
65
-
66
- def update_data(self, new_data):
67
- key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
68
- if key in self.cache_data[Const.DATA]:
69
- self.cache_data[Const.DATA][key].update(new_data[key])
70
- else:
71
- self.cache_data[Const.DATA].update(new_data)
72
-
73
- def flush_data_when_buffer_is_full(self):
74
- if len(self.cache_data[Const.DATA]) >= self.buffer_size:
75
- self.write_data_json(self.dump_file_path)
76
-
77
- def update_stack(self, new_data):
78
- self.cache_stack.update(new_data)
79
-
80
- def update_construct(self, new_data):
81
- self.cache_construct.update(new_data)
82
-
83
- def write_data_json(self, file_path):
84
- logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
85
- if Path(file_path).exists() and os.path.getsize(file_path) > 0:
86
- with FileOpen(file_path, "r+") as f:
87
- fcntl.flock(f, fcntl.LOCK_EX)
88
- data_to_write = json.load(f)
89
- fcntl.flock(f, fcntl.LOCK_UN)
90
- else:
91
- self.init_json['data_path'] = self.dump_tensor_data_dir
92
- data_to_write = self.init_json
93
- data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
94
- with FileOpen(file_path, 'w+') as f:
95
- fcntl.flock(f, fcntl.LOCK_EX)
96
- json.dump(data_to_write, f, indent=1)
97
- fcntl.flock(f, fcntl.LOCK_UN)
98
-
99
- self.cache_data[Const.DATA].clear()
100
-
101
- def write_stack_info_json(self, file_path):
102
- with FileOpen(file_path, 'w+') as f:
103
- fcntl.flock(f, fcntl.LOCK_EX)
104
- json.dump(self.cache_stack, f, indent=1)
105
- fcntl.flock(f, fcntl.LOCK_UN)
106
-
107
- def write_construct_info_json(self, file_path):
108
- with FileOpen(file_path, 'w+') as f:
109
- fcntl.flock(f, fcntl.LOCK_EX)
110
- json.dump(self.cache_construct, f, indent=1)
111
- fcntl.flock(f, fcntl.LOCK_UN)
112
-
113
- def write_json(self):
114
- self.write_data_json(self.dump_file_path)
115
- self.write_stack_info_json(self.stack_file_path)
116
- self.write_construct_info_json(self.construct_file_path)
1
+ import os
2
+ import csv
3
+
4
+ from msprobe.core.common.file_utils import change_mode, FileOpen
5
+ from msprobe.core.common.log import logger
6
+ from msprobe.core.common.const import Const, FileCheckConst
7
+ from msprobe.core.common.file_utils import remove_path, load_json, save_json
8
+
9
+
10
+ class DataWriter:
11
+
12
+ def __init__(self, init_json=None) -> None:
13
+ self.dump_count = 0
14
+ self.init_json = init_json
15
+ self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
16
+ self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
17
+ self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
18
+ self.free_benchmark_file_path = None
19
+ self.dump_tensor_data_dir = None
20
+ self.buffer_size = 1000
21
+ self.cache_data = {Const.DATA: {}}
22
+ self.cache_stack = {}
23
+ self.cache_construct = {}
24
+
25
+ @staticmethod
26
+ def write_data_to_csv(result: list, result_header: tuple, file_path: str):
27
+ if not result:
28
+ return
29
+ is_exists = os.path.exists(file_path)
30
+ append = "a+" if is_exists else "w+"
31
+ with FileOpen(file_path, append) as csv_file:
32
+ spawn_writer = csv.writer(csv_file)
33
+ if not is_exists:
34
+ spawn_writer.writerow(result_header)
35
+ spawn_writer.writerows([result,])
36
+ is_new_file = not is_exists
37
+ if is_new_file:
38
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
39
+
40
+ def initialize_json_file(self, **kwargs):
41
+ kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
42
+ save_json(self.dump_file_path, kwargs)
43
+
44
+ empty_dict = {}
45
+ remove_path(self.stack_file_path)
46
+ save_json(self.stack_file_path, empty_dict)
47
+
48
+ remove_path(self.construct_file_path)
49
+ save_json(self.construct_file_path, empty_dict)
50
+
51
+ def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
52
+ free_benchmark_file_path):
53
+ self.dump_file_path = dump_file_path
54
+ self.stack_file_path = stack_file_path
55
+ self.construct_file_path = construct_file_path
56
+ self.dump_tensor_data_dir = dump_data_dir
57
+ self.free_benchmark_file_path = free_benchmark_file_path
58
+
59
+ def update_data(self, new_data):
60
+ key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
61
+ if key in self.cache_data[Const.DATA]:
62
+ self.cache_data[Const.DATA][key].update(new_data[key])
63
+ else:
64
+ self.cache_data[Const.DATA].update(new_data)
65
+
66
+ def flush_data_when_buffer_is_full(self):
67
+ if len(self.cache_data[Const.DATA]) >= self.buffer_size:
68
+ self.write_data_json(self.dump_file_path)
69
+
70
+ def update_stack(self, new_data):
71
+ self.cache_stack.update(new_data)
72
+
73
+ def update_construct(self, new_data):
74
+ self.cache_construct.update(new_data)
75
+
76
+ def write_data_json(self, file_path):
77
+ logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
78
+ if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
79
+ data_to_write = load_json(file_path)
80
+ else:
81
+ self.init_json['data_path'] = self.dump_tensor_data_dir
82
+ data_to_write = self.init_json
83
+ data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
84
+ save_json(file_path, data_to_write, indent=1)
85
+ self.cache_data[Const.DATA].clear()
86
+
87
+ def write_stack_info_json(self, file_path):
88
+ save_json(file_path, self.cache_stack, indent=1)
89
+
90
+ def write_construct_info_json(self, file_path):
91
+ save_json(file_path, self.cache_construct, indent=1)
92
+
93
+ def write_json(self):
94
+ self.write_data_json(self.dump_file_path)
95
+ self.write_stack_info_json(self.stack_file_path)
96
+ self.write_construct_info_json(self.construct_file_path)
@@ -1,178 +1,178 @@
1
- from abc import ABC, abstractmethod
2
- from msprobe.core.common.exceptions import ScopeException
3
- from msprobe.core.common.const import Const
4
-
5
-
6
- def build_scope(scope_class, scope=None, api_list=None):
7
- if not scope and not api_list:
8
- return None
9
- if scope is None:
10
- scope = []
11
- if api_list is None:
12
- api_list = []
13
- if scope_class:
14
- return scope_class(scope, api_list)
15
- return build_range_scope_according_to_scope_name(scope, api_list)
16
-
17
-
18
- def build_range_scope_according_to_scope_name(scope, api_list):
19
- api_range_scope = APIRangeScope(scope, api_list)
20
- module_range_scope = ModuleRangeScope(scope, api_list)
21
- if not scope: # 如果没有scope参数则用哪类scope都一样
22
- return api_range_scope
23
- if api_range_scope.is_valid and module_range_scope.is_valid:
24
- raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
25
- elif api_range_scope.is_valid:
26
- return api_range_scope
27
- elif module_range_scope.is_valid:
28
- return module_range_scope
29
- else:
30
- raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
31
-
32
-
33
- class BaseScope(ABC):
34
- Module_Type_Module = "Module"
35
- Module_Type_API = "api"
36
-
37
- def __init__(self, scope, api_list):
38
- scope, api_list = self.rectify_args(scope, api_list)
39
- self.scope = scope
40
- self.api_list = api_list
41
-
42
- @staticmethod
43
- def rectify_args(scope, api_list):
44
- if not isinstance(api_list, list):
45
- raise ScopeException(ScopeException.InvalidApiStr,
46
- f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
47
- for api in api_list:
48
- if not isinstance(api, str):
49
- raise ScopeException(ScopeException.InvalidApiStr,
50
- f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
51
- if isinstance(scope, str):
52
- scope = [scope]
53
- return scope, api_list
54
- if not isinstance(scope, list):
55
- raise ScopeException(ScopeException.InvalidScope,
56
- f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
57
- for s in scope:
58
- if not isinstance(s, str):
59
- raise ScopeException(ScopeException.InvalidScope,
60
- f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
61
- return scope, api_list
62
-
63
- @abstractmethod
64
- def check(self, name):
65
- pass
66
-
67
- def check_api_list(self, api_name):
68
- if not self.api_list:
69
- return True
70
- for api_str in self.api_list:
71
- if api_str in api_name:
72
- return True
73
- return False
74
-
75
-
76
- class ListScope(BaseScope):
77
- @staticmethod
78
- def rectify_args(scope, api_list):
79
- if scope and api_list:
80
- raise ScopeException(ScopeException.ArgConflict,
81
- f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
82
- return super(ListScope, ListScope).rectify_args(scope, api_list)
83
-
84
- def check(self, module_name):
85
- if not self.scope or module_name in self.scope:
86
- return self.check_api_list(module_name)
87
- return False
88
-
89
-
90
- class RangeScope(BaseScope, ABC):
91
-
92
- def __init__(self, *args):
93
- super().__init__(*args)
94
- self.in_scope = False
95
- self.is_valid = self.check_scope_is_valid()
96
-
97
-
98
- @staticmethod
99
- def rectify_args(scope, api_list):
100
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
101
- if isinstance(scope, list):
102
- if len(scope) == 1:
103
- scope.append(scope[0])
104
- elif len(scope) > 2:
105
- raise ScopeException(ScopeException.InvalidScope,
106
- f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
107
-
108
- return scope, api_list
109
-
110
- @abstractmethod
111
- def check_scope_is_valid(self):
112
- pass
113
-
114
- def begin_module(self, module_name):
115
- pass
116
-
117
- def end_module(self, module_name):
118
- pass
119
-
120
-
121
- class APIRangeScope(RangeScope):
122
- def check_scope_is_valid(self):
123
- if not self.scope:
124
- return True
125
- scope_start_type = self.scope[0].split(Const.SEP)[0]
126
- if scope_start_type == BaseScope.Module_Type_Module:
127
- return False
128
- scope_stop_type = self.scope[1].split(Const.SEP)[0]
129
- if scope_stop_type == BaseScope.Module_Type_Module:
130
- return False
131
- return True
132
-
133
- def check(self, api_name):
134
- if self.scope and api_name == self.scope[0]:
135
- self.in_scope = True
136
-
137
- if not self.scope or self.in_scope:
138
- result = self.check_api_list(api_name)
139
- else:
140
- result = False
141
-
142
- if self.scope and api_name == self.scope[1]:
143
- self.in_scope = False
144
- return result
145
-
146
-
147
- class ModuleRangeScope(RangeScope):
148
- """
149
- 模块与api不同的是,模块内部还有子结构需要dump,
150
- 需要用pre_hook和full_backward_hook来精确控制module的开始和结束,
151
- 在这些hook触发时调用begin_module和end_module做区间控制
152
- """
153
- def check_scope_is_valid(self):
154
- if not self.scope:
155
- return True
156
- scope_start_type = self.scope[0].split(Const.SEP)[0]
157
- scope_stop_type = self.scope[1].split(Const.SEP)[0]
158
- if scope_start_type == BaseScope.Module_Type_Module and \
159
- scope_stop_type == BaseScope.Module_Type_Module:
160
- return True
161
- return False
162
-
163
- def begin_module(self, module_name):
164
- if not self.scope:
165
- return
166
- if module_name == self.scope[0]:
167
- self.in_scope = True
168
-
169
- def end_module(self, module_name):
170
- if not self.scope:
171
- return
172
- if module_name == self.scope[1]:
173
- self.in_scope = False
174
-
175
- def check(self, module_name):
176
- if not self.scope or self.in_scope:
177
- return self.check_api_list(module_name)
178
- return False
1
+ from abc import ABC, abstractmethod
2
+ from msprobe.core.common.exceptions import ScopeException
3
+ from msprobe.core.common.const import Const
4
+
5
+
6
def build_scope(scope_class, scope=None, api_list=None):
    """Create a scope object, or return None when nothing is configured.

    Args:
        scope_class: Callable taking ``(scope, api_list)``. When falsy, the
            range scope type is chosen from the scope names instead.
        scope: Scope configuration; ``None`` is replaced by an empty list.
        api_list: API filter configuration; ``None`` is replaced by an empty list.

    Returns:
        ``None`` if both ``scope`` and ``api_list`` are falsy, otherwise the
        object produced by ``scope_class`` or by
        ``build_range_scope_according_to_scope_name``.
    """
    if not (scope or api_list):
        return None

    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list

    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
16
+
17
+
18
def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick the range scope type that accepts the given scope endpoints.

    Both an ``APIRangeScope`` and a ``ModuleRangeScope`` are constructed from
    the same arguments, and the one whose ``is_valid`` flag is set is returned.

    Raises:
        ScopeException: With ``InvalidScope`` when a non-empty scope is
            accepted by both candidates or by neither of them.
    """
    api_candidate = APIRangeScope(scope, api_list)
    module_candidate = ModuleRangeScope(scope, api_list)

    # With no scope argument, either kind of range scope behaves the same.
    if not scope:
        return api_candidate

    api_ok = api_candidate.is_valid
    module_ok = module_candidate.is_valid
    if api_ok and module_ok:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if api_ok:
        return api_candidate
    if module_ok:
        return module_candidate
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
31
+
32
+
33
class BaseScope(ABC):
    """Abstract base for scope filters built from a scope and an API list."""

    # Leading name segments compared against when classifying scope entries.
    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        # rectify_args is looked up on the instance so subclasses can override it.
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate both arguments and normalise ``scope`` to a list.

        Returns:
            The ``(scope, api_list)`` pair; a string ``scope`` is wrapped in a
            one-element list.

        Raises:
            ScopeException: ``InvalidApiStr`` when ``api_list`` is not a list
                of strings, ``InvalidScope`` when ``scope`` is neither a
                string nor a list of strings.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for api in api_list:
            if isinstance(api, str):
                continue
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")

        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for s in scope:
            if isinstance(s, str):
                continue
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        pass

    def check_api_list(self, api_name):
        """Return True when ``api_list`` is empty or one entry is a substring of ``api_name``."""
        if not self.api_list:
            return True
        return any(api_str in api_name for api_str in self.api_list)
74
+
75
+
76
class ListScope(BaseScope):
    """Scope that selects names listed explicitly in ``scope``."""

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject configuring both arguments at once, then apply the base validation.

        Raises:
            ScopeException: ``ArgConflict`` when ``scope`` and ``api_list``
                are both non-empty.
        """
        if scope and api_list:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return BaseScope.rectify_args(scope, api_list)

    def check(self, module_name):
        """Return whether ``module_name`` passes the scope list and the API filter."""
        if self.scope and module_name not in self.scope:
            return False
        return self.check_api_list(module_name)
88
+
89
+
90
class RangeScope(BaseScope, ABC):
    """Abstract scope described by a start endpoint and a stop endpoint.

    Attributes set by ``__init__``:
        in_scope: Range state flag, initially False; subclasses toggle it.
        is_valid: Result of ``check_scope_is_valid()`` for the rectified scope.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate the arguments and normalise ``scope`` to zero or two endpoints.

        A single endpoint is used as both start and stop of the range.

        Returns:
            The ``(scope, api_list)`` pair.

        Raises:
            ScopeException: ``InvalidScope`` when ``scope`` has more than two
                entries, plus anything raised by the base validation.
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # Build a new list instead of appending in place, so the
                # caller's own list is never modified.
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        pass

    def begin_module(self, module_name):
        # No-op hook; subclasses that track module boundaries override it.
        pass

    def end_module(self, module_name):
        # No-op hook; subclasses that track module boundaries override it.
        pass
119
+
120
+
121
class APIRangeScope(RangeScope):
    """Range scope whose two endpoints are API names.

    The range opens when ``check`` sees the start name and closes after
    ``check`` has evaluated the stop name, so both endpoints are included.
    """

    def check_scope_is_valid(self):
        """Return False if either endpoint's leading name segment is the module type."""
        if not self.scope:
            return True
        # Endpoints are read lazily, start first, then stop.
        for position in (0, 1):
            endpoint_type = self.scope[position].split(Const.SEP)[0]
            if endpoint_type == BaseScope.Module_Type_Module:
                return False
        return True

    def check(self, api_name):
        """Update the range state for ``api_name`` and report whether it is selected."""
        has_range = bool(self.scope)

        if has_range and api_name == self.scope[0]:
            self.in_scope = True

        if self.in_scope or not has_range:
            result = self.check_api_list(api_name)
        else:
            result = False

        # Close the range only after the stop endpoint itself was evaluated.
        if has_range and api_name == self.scope[1]:
            self.in_scope = False
        return result
145
+
146
+
147
class ModuleRangeScope(RangeScope):
    """
    Unlike an API, a module has internal sub-structures that also need to be dumped.
    pre_hook and full_backward_hook are needed to control precisely where a module
    starts and ends; when those hooks fire, begin_module and end_module are called
    to drive the range state.
    """

    def check_scope_is_valid(self):
        """Return True when the scope is empty or both endpoints are module names."""
        if not self.scope:
            return True
        start_type = self.scope[0].split(Const.SEP)[0]
        stop_type = self.scope[1].split(Const.SEP)[0]
        return (start_type == BaseScope.Module_Type_Module
                and stop_type == BaseScope.Module_Type_Module)

    def begin_module(self, module_name):
        """Open the range when ``module_name`` is the start endpoint."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when ``module_name`` is the stop endpoint."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Report whether ``module_name`` is selected under the current range state."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)