mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -1,100 +0,0 @@
1
- import math
2
- import abc
3
- import numpy as np
4
- from msprobe.core.common.utils import get_header_index
5
- from msprobe.core.common.const import CompareConst
6
-
7
-
8
- class HighlightCheck(abc.ABC):
9
- @abc.abstractmethod
10
- def apply(self, info, color_columns, summary_compare):
11
- raise NotImplementedError
12
-
13
-
14
- class CheckOrderMagnitude(HighlightCheck):
15
- """检查Max diff的数量级差异"""
16
- def apply(self, info, color_columns, summary_compare=True):
17
- api_in, api_out, num = info
18
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
19
- if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
20
- return
21
- in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
22
- out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
23
- if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
24
- color_columns.yellow.append(num)
25
-
26
-
27
- class CheckOneThousandErrorRatio(HighlightCheck):
28
- """检查千分误差比率"""
29
- def apply(self, info, color_columns, summary_compare=True):
30
- api_in, api_out, num = info
31
- one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
32
- if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)):
33
- return
34
- if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
35
- color_columns.red.append(num)
36
- elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
37
- color_columns.yellow.append(num)
38
-
39
-
40
- class CheckCosineSimilarity(HighlightCheck):
41
- """检查余弦相似度"""
42
- def apply(self, info, color_columns, summary_compare=True):
43
- api_in, api_out, num = info
44
- cosine_index = get_header_index('Cosine', summary_compare)
45
- if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
46
- return
47
- if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
48
- color_columns.yellow.append(num)
49
-
50
-
51
- class CheckMaxRelativeDiff(HighlightCheck):
52
- """检查最大相对差异"""
53
- def apply(self, info, color_columns, summary_compare=True):
54
- api_in, api_out, num = info
55
- max_diff_index = get_header_index('Max diff', summary_compare)
56
- bench_max_index = get_header_index('Bench max', summary_compare)
57
- input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
58
- output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
59
- if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
60
- (float, int)):
61
- return
62
- if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
63
- color_columns.red.append(num)
64
- elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW:
65
- color_columns.yellow.append(num)
66
-
67
-
68
- class CheckOverflow(HighlightCheck):
69
- """检查是否存在溢出"""
70
- def apply(self, info, color_columns, summary_compare=True):
71
- line, num = info
72
- npu_max_index = get_header_index('NPU max', summary_compare)
73
- npu_min_index = get_header_index('NPU min', summary_compare)
74
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
75
- if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
76
- line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
77
- color_columns.red.append(num)
78
- return
79
- # check if Max_Diff > 1e+10
80
- if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
81
- color_columns.red.append(num)
82
-
83
-
84
- class HighlightRules:
85
- """高亮规则集合,用于检查API的误差"""
86
- # 适用于每行的规则
87
- basic_rules = {
88
- "check_overflow": CheckOverflow()
89
- }
90
-
91
- # 用于比较输入和输出的规则
92
- compare_rules = {
93
- "check_order_magnitude": CheckOrderMagnitude(),
94
- "check_one_thousand_error": CheckOneThousandErrorRatio(),
95
- "check_cosine_similarity": CheckCosineSimilarity()
96
- }
97
- summary_compare_rules = {
98
- "check_order_magnitude": CheckOrderMagnitude(),
99
- "check_max_relative_diff": CheckMaxRelativeDiff(),
100
- }
@@ -1,345 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import os
18
- import uuid
19
-
20
- from unittest import TestCase
21
- from unittest.mock import patch, MagicMock, mock_open
22
-
23
- from msprobe.core.common.log import logger
24
- from msprobe.core.common.const import Const
25
- from msprobe.core.common.utils import (CompareException,
26
- check_seed_all,
27
- check_inplace_op,
28
- make_dump_path_if_not_exists,
29
- check_mode_valid,
30
- check_switch_valid,
31
- check_dump_mode_valid,
32
- check_summary_mode_valid,
33
- check_summary_only_valid,
34
- check_file_or_directory_path,
35
- check_compare_param,
36
- check_configuration_param,
37
- is_starts_with,
38
- _check_json,
39
- check_json_file,
40
- check_file_size,
41
- check_regex_prefix_format_valid,
42
- get_dump_data_path,
43
- task_dumppath_get)
44
- from msprobe.core.common.file_check import FileCheckConst
45
-
46
-
47
- class TestUtils(TestCase):
48
- @patch.object(logger, "error")
49
- def test_check_seed_all(self, mock_error):
50
- self.assertIsNone(check_seed_all(1234, True))
51
- self.assertIsNone(check_seed_all(0, True))
52
- self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True))
53
-
54
- with self.assertRaises(CompareException) as context:
55
- check_seed_all(-1, True)
56
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
57
- mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
58
-
59
- with self.assertRaises(CompareException) as context:
60
- check_seed_all(Const.MAX_SEED_VALUE + 1, True)
61
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
62
- mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
63
-
64
- with self.assertRaises(CompareException) as context:
65
- check_seed_all("1234", True)
66
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
67
- mock_error.assert_called_with("Seed must be integer.")
68
-
69
- with self.assertRaises(CompareException) as context:
70
- check_seed_all(1234, 1)
71
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
72
- mock_error.assert_called_with("seed_all mode must be bool.")
73
-
74
- def test_check_inplace_op(self):
75
- test_prefix_1 = "Distributed.broadcast.0.forward.input.0"
76
- self.assertTrue(check_inplace_op(test_prefix_1))
77
- test_prefix_2 = "Distributed_broadcast_0_forward_input_0"
78
- self.assertFalse(check_inplace_op(test_prefix_2))
79
- test_prefix_3 = "Torch.sum.0.backward.output.0"
80
- self.assertFalse(check_inplace_op(test_prefix_3))
81
-
82
- @patch.object(logger, "error")
83
- def test_make_dump_path_if_not_exists(self, mock_error):
84
- file_path = os.path.realpath(__file__)
85
- dirname = os.path.dirname(file_path) + str(uuid.uuid4())
86
-
87
- def test_mkdir(self, **kwargs):
88
- raise OSError
89
-
90
- if not os.path.exists(dirname):
91
- with patch("msprobe.core.common.utils.Path.mkdir", new=test_mkdir):
92
- with self.assertRaises(CompareException) as context:
93
- make_dump_path_if_not_exists(dirname)
94
- self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
95
-
96
- make_dump_path_if_not_exists(file_path)
97
- mock_error.assert_called_with(f"{file_path} already exists and is not a directory.")
98
-
99
- def test_check_mode_valid(self):
100
- with self.assertRaises(ValueError) as context:
101
- check_mode_valid("all", scope="scope")
102
- self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.")
103
-
104
- with self.assertRaises(ValueError) as context:
105
- check_mode_valid("all", api_list="api_list")
106
- self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.")
107
-
108
- mode = "all_list"
109
- with self.assertRaises(CompareException) as context:
110
- check_mode_valid(mode)
111
- self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE)
112
- self.assertEqual(str(context.exception),
113
- f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}")
114
-
115
- mode = "list"
116
- with self.assertRaises(ValueError) as context:
117
- check_mode_valid(mode)
118
- self.assertEqual(str(context.exception),
119
- "set_dump_switch, scope param set invalid, it's should not be an empty list.")
120
-
121
- @patch.object(logger, "error")
122
- def test_check_switch_valid(self, mock_error):
123
- with self.assertRaises(CompareException) as context:
124
- check_switch_valid("Close")
125
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
126
- mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.")
127
-
128
- @patch.object(logger, "warning")
129
- def test_check_dump_mode_valid(self, mock_warning):
130
- dump_mode = check_dump_mode_valid("all")
131
- mock_warning.assert_called_with("Please set dump_mode as a list.")
132
- self.assertEqual(dump_mode, ["forward", "backward", "input", "output"])
133
-
134
- with self.assertRaises(ValueError) as context:
135
- check_dump_mode_valid("all_forward")
136
- self.assertEqual(str(context.exception),
137
- "Please set dump_mode as a list containing one or more of the following: " +
138
- "'all', 'forward', 'backward', 'input', 'output'.")
139
-
140
- def test_check_summary_mode_valid(self):
141
- with self.assertRaises(CompareException) as context:
142
- check_summary_mode_valid("MD5")
143
- self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE)
144
- self.assertEqual(str(context.exception), "The summary_mode is not valid")
145
-
146
- @patch.object(logger, "error")
147
- def test_check_summary_only_valid(self, mock_error):
148
- summary_only = check_summary_only_valid(True)
149
- self.assertTrue(summary_only)
150
-
151
- with self.assertRaises(CompareException) as context:
152
- check_summary_only_valid("True")
153
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
154
- mock_error.assert_called_with("Params summary_only only support True or False.")
155
-
156
- def test_check_file_or_directory_path(self):
157
- class TestFileChecker:
158
- file_path = ""
159
- path_type = ""
160
- ability = ""
161
- checked = False
162
-
163
- def __init__(self, file_path, path_type, ability=None):
164
- TestFileChecker.file_path = file_path
165
- TestFileChecker.path_type = path_type
166
- TestFileChecker.ability = ability
167
-
168
- def common_check(self):
169
- TestFileChecker.checked = True
170
-
171
- file_path = os.path.realpath(__file__)
172
- dirname = os.path.dirname(file_path)
173
-
174
- with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
175
- check_file_or_directory_path(file_path, isdir=False)
176
- self.assertTrue(TestFileChecker.checked)
177
- self.assertEqual(TestFileChecker.file_path, file_path)
178
- self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE)
179
- self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE)
180
-
181
- TestFileChecker.checked = False
182
- with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker):
183
- check_file_or_directory_path(dirname, isdir=True)
184
- self.assertTrue(TestFileChecker.checked)
185
- self.assertEqual(TestFileChecker.file_path, dirname)
186
- self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR)
187
- self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE)
188
-
189
- @patch.object(logger, "error")
190
- def test_check_compare_param(self, mock_error):
191
- params = {
192
- "npu_json_path": "npu_json_path",
193
- "bench_json_path": "bench_json_path",
194
- "stack_json_path": "stack_json_path",
195
- "npu_dump_data_dir": "npu_dump_data_dir",
196
- "bench_dump_data_dir": "bench_dump_data_dir"
197
- }
198
-
199
- call_args = [
200
- ("npu_json_path", False),
201
- ("bench_json_path", False),
202
- ("stack_json_path", False),
203
- ("npu_dump_data_dir", True),
204
- ("bench_dump_data_dir", True),
205
- ("output_path", True),
206
- ("npu_json_path", False),
207
- ("bench_json_path", False),
208
- ("stack_json_path", False),
209
- ("output_path", True)
210
- ]
211
-
212
- with self.assertRaises(CompareException) as context:
213
- check_compare_param("npu_json_path", "output_path")
214
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
215
- mock_error.assert_called_with("Invalid input parameters")
216
-
217
- mock_check_file_or_directory_path = MagicMock()
218
- mock_check_json_file = MagicMock()
219
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
220
- patch("msprobe.core.common.utils.check_json_file", new=mock_check_json_file), \
221
- patch("msprobe.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
222
- check_compare_param(params, "output_path")
223
- check_compare_param(params, "output_path", summary_compare=False, md5_compare=True)
224
- for i in range(len(call_args)):
225
- self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], call_args[i])
226
- self.assertEqual(len(mock_check_json_file.call_args[0]), 4)
227
- self.assertEqual(mock_check_json_file.call_args[0][0], params)
228
-
229
- @patch.object(logger, "error")
230
- def test_check_configuration_param(self, mock_error):
231
- with self.assertRaises(CompareException) as context:
232
- check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False)
233
- self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
234
- mock_error.assert_called_with("Invalid input parameters which should be only bool type.")
235
-
236
- def test_is_starts_with(self):
237
- string = "input_slot0"
238
- self.assertFalse(is_starts_with(string, []))
239
- self.assertFalse(is_starts_with("", ["input"]))
240
- self.assertFalse(is_starts_with(string, ["output"]))
241
- self.assertTrue(is_starts_with(string, ["input", "output"]))
242
-
243
- @patch.object(logger, "error")
244
- def test__check_json(self, mock_error):
245
- class TestOpen:
246
- def __init__(self, string):
247
- self.string = string
248
-
249
- def readline(self):
250
- return self.string
251
-
252
- def seek(self, begin, end):
253
- self.string = str(begin) + "_" + str(end)
254
-
255
- with self.assertRaises(CompareException) as context:
256
- _check_json(TestOpen(""), "test.json")
257
- self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE)
258
- mock_error.assert_called_with("dump file test.json have empty line!")
259
-
260
- handler = TestOpen("jons file\n")
261
- _check_json(handler, "test.json")
262
- self.assertEqual(handler.string, "0_0")
263
-
264
- @patch("msprobe.core.common.utils._check_json")
265
- def test_check_json_file(self, _mock_check_json):
266
- input_param = {
267
- "npu_json_path": "npu_json_path",
268
- "bench_json_path": "bench_json_path",
269
- "stack_json_path": "stack_json_path"
270
- }
271
- check_json_file(input_param, "npu_json", "bench_json", "stack_json")
272
- self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path"))
273
- self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path"))
274
- self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path"))
275
-
276
- @patch.object(logger, "error")
277
- def test_check_file_size(self, mock_error):
278
- with patch("msprobe.core.common.utils.os.path.getsize", return_value=120):
279
- with self.assertRaises(CompareException) as context:
280
- check_file_size("input_file", 100)
281
- self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR)
282
- mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.")
283
-
284
- def test_check_regex_prefix_format_valid(self):
285
- prefix = "A" * 21
286
- with self.assertRaises(ValueError) as context:
287
- check_regex_prefix_format_valid(prefix)
288
- self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, "
289
- f"while current length is {len(prefix)}")
290
-
291
- prefix = "(prefix)"
292
- with self.assertRaises(ValueError) as context:
293
- check_regex_prefix_format_valid(prefix)
294
- self.assertEqual(str(context.exception), f"prefix contains invalid characters, "
295
- f"prefix pattern {Const.REGEX_PREFIX_PATTERN}")
296
-
297
- @patch("msprobe.core.common.utils.check_file_or_directory_path")
298
- def test_get_dump_data_path(self, mock_check_file_or_directory_path):
299
- file_path = os.path.realpath(__file__)
300
- dirname = os.path.dirname(file_path)
301
-
302
- dump_data_path, file_is_exist = get_dump_data_path(dirname)
303
- self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True))
304
- self.assertEqual(dump_data_path, dirname)
305
- self.assertTrue(file_is_exist)
306
-
307
- @patch.object(logger, "error")
308
- def test_task_dumppath_get(self, mock_error):
309
- input_param = {
310
- "npu_json_path": None,
311
- "bench_json_path": "bench_json_path"
312
- }
313
- npu_json = {
314
- "task": Const.TENSOR,
315
- "dump_data_dir": "dump_data_dir",
316
- "data": "data"
317
- }
318
-
319
- with self.assertRaises(CompareException) as context:
320
- task_dumppath_get(input_param)
321
- self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
322
- mock_error.assert_called_with("Please check the json path is valid.")
323
-
324
- input_param["npu_json_path"] = "npu_json_path"
325
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
326
- patch("msprobe.core.common.utils.json.load", return_value=npu_json):
327
- summary_compare, md5_compare = task_dumppath_get(input_param)
328
- self.assertFalse(summary_compare)
329
- self.assertFalse(md5_compare)
330
-
331
- npu_json["task"] = Const.STATISTICS
332
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
333
- patch("msprobe.core.common.utils.json.load", return_value=npu_json), \
334
- patch("msprobe.core.common.utils.md5_find", return_value=True):
335
- summary_compare, md5_compare = task_dumppath_get(input_param)
336
- self.assertFalse(summary_compare)
337
- self.assertTrue(md5_compare)
338
-
339
- npu_json["task"] = Const.OVERFLOW_CHECK
340
- with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \
341
- patch("msprobe.core.common.utils.json.load", return_value=npu_json):
342
- with self.assertRaises(CompareException) as context:
343
- task_dumppath_get(input_param)
344
- self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR)
345
- mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.")
@@ -1,47 +0,0 @@
1
- import unittest
2
- from unittest.mock import patch, mock_open, MagicMock
3
-
4
- from msprobe.core.common.utils import Const
5
- from msprobe.core.data_dump.data_collector import DataCollector
6
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
7
- from msprobe.pytorch.pt_config import parse_json_config
8
-
9
-
10
- class TestDataCollector(unittest.TestCase):
11
- def setUp(self):
12
- mock_json_data = {
13
- "dump_path": "./ut_dump",
14
- }
15
- with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \
16
- patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data):
17
- common_config, task_config = parse_json_config("./config.json", Const.STATISTICS)
18
- config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1")
19
- self.data_collector = DataCollector(config)
20
-
21
- def test_update_data(self):
22
- self.data_collector.config.task = Const.OVERFLOW_CHECK
23
- self.data_collector.data_processor.has_overflow = True
24
- with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
25
- result1 = self.data_collector.update_data("test message", "test1:")
26
- self.assertEqual(result1, "test1:Overflow detected.")
27
-
28
- self.data_collector.data_processor.has_overflow = False
29
- result2 = self.data_collector.update_data("test message", "test2:")
30
- self.assertEqual(result2, "test2:No Overflow, OK.")
31
-
32
- self.data_collector.config.task = Const.STATISTICS
33
- self.data_collector.data_processor.has_overflow = True
34
- with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None):
35
- result3 = self.data_collector.update_data("test message", "test3")
36
- self.assertEqual(result3, "test3")
37
-
38
- def test_pre_forward_data_collect(self):
39
- self.data_collector.check_scope_and_pid = MagicMock(return_value=False)
40
- self.data_collector.is_inplace = MagicMock(return_value=False)
41
- self.data_collector.data_processor.analyze_pre_forward = MagicMock()
42
- name = "TestModule.forward"
43
- pid = 123
44
-
45
- self.data_collector.pre_forward_data_collect(name, None, pid, None)
46
- self.data_collector.check_scope_and_pid.assert_called_once_with(
47
- self.data_collector.scope, "TestModule.backward", 123)
@@ -1,183 +0,0 @@
1
- import unittest
2
- from msprobe.core.data_dump.json_writer import DataWriter
3
-
4
- import os
5
- import csv
6
- from msprobe.core.common.file_check import FileOpen
7
- from msprobe.core.common import utils
8
- from pathlib import Path
9
- import json
10
-
11
- class TestDataWriter(unittest.TestCase):
12
- def test_write_data_to_csv(self):
13
- cur_path = os.path.dirname(os.path.realpath(__file__))
14
- file_path = os.path.join(cur_path, "test.csv")
15
-
16
- if os.path.exists(file_path):
17
- utils.remove_path(file_path)
18
-
19
- data = {"A":"1", "B":"2", "C":"3"}
20
- result = data.values()
21
- header = data.keys()
22
- DataWriter.write_data_to_csv(result, header, file_path)
23
- with FileOpen(file_path, "r") as f:
24
- reader = csv.DictReader(f)
25
- column_first = [row for row in reader][0]
26
- self.assertEqual(data, column_first)
27
-
28
-
29
-
30
-
31
- data = {"A":"4", "B":"5", "C":"6"}
32
- result = data.values()
33
- header = data.keys()
34
- DataWriter.write_data_to_csv(result, header, file_path)
35
- with FileOpen(file_path, "r") as f:
36
- reader = csv.DictReader(f)
37
- column_last = [row for row in reader][-1]
38
- self.assertEqual(data, column_last)
39
-
40
- utils.remove_path(file_path)
41
-
42
- def test_initialize_json_file(self):
43
- cur_path = os.path.dirname(os.path.realpath(__file__))
44
- dump_tensor_data_dir = os.path.join(cur_path, "dump_tensor_data.json")
45
- dump_file_path = os.path.join(cur_path, "dump_file.json")
46
- stack_file_path = os.path.join(cur_path, "stack_file.json")
47
- construct_file_path = os.path.join(cur_path, "construct_file.json")
48
- if not os.path.exists(stack_file_path):
49
- Path(stack_file_path).touch()
50
- if not os.path.exists(construct_file_path):
51
- Path(construct_file_path).touch()
52
-
53
- test = DataWriter()
54
- test.stack_file_path = stack_file_path
55
- test.dump_file_path = dump_file_path
56
- test.dump_tensor_data_dir = dump_tensor_data_dir
57
- test.construct_file_path = construct_file_path
58
-
59
- test.initialize_json_file()
60
-
61
- with open(dump_file_path) as f:
62
- load_data = json.load(f)
63
- result = {"dump_data_dir": dump_tensor_data_dir, "data": {}}
64
- self.assertEqual(result, load_data)
65
- is_exist_1 = os.path.exists(test.stack_file_path)
66
- self.assertTrue(is_exist_1)
67
- os.access(test.stack_file_path, os.R_OK)
68
- os.access(test.stack_file_path, os.W_OK)
69
- is_exist_2 = os.path.exists(test.construct_file_path)
70
- self.assertTrue(is_exist_2)
71
- os.access(test.construct_file_path, os.R_OK)
72
- os.access(test.construct_file_path, os.W_OK)
73
-
74
- os.remove(construct_file_path)
75
- os.remove(stack_file_path)
76
- os.remove(dump_file_path)
77
-
78
- def test_update_dump_paths(self):
79
- test = DataWriter()
80
- self.assertTrue(test.dump_file_path == None)
81
-
82
- cur_path = os.path.dirname(os.path.realpath(__file__))
83
- test_path = os.path.join(cur_path, "test1.json")
84
-
85
- test.update_dump_paths(test_path, test_path, test_path, test_path, test_path)
86
- self.assertTrue(test.dump_file_path == test_path)
87
- self.assertTrue(test.stack_file_path == test_path)
88
- self.assertTrue(test.construct_file_path == test_path)
89
- self.assertTrue(test.dump_tensor_data_dir == test_path)
90
- self.assertTrue(test.free_benchmark_file_path == test_path)
91
-
92
- def test_update_data(self):
93
- data = {"A":"1", "B":"2", "C":{"D":"2"}}
94
- test = DataWriter()
95
- test.cache_data["data"]["test_1"] = True
96
- test.cache_data["data"]["test_2"] = False
97
-
98
- test.update_data(data)
99
- self.assertEqual(test.cache_data["data"]["A"], "1")
100
-
101
- new_data = {"C":{"F":3}}
102
- test.update_data(new_data)
103
- self.assertEqual(test.cache_data["data"]["C"]["F"], 3)
104
-
105
-
106
- def test_flush_data_when_buffer_is_full_and_test_write_data_json(self):
107
- data = {"A":"1", "B":"2", "data":{}}
108
- test = DataWriter()
109
- test.buffer_size = 1
110
- test.cache_data["data"] = {"A":"1", "B":"2", "C":"3"}
111
-
112
- self.assertTrue(len(test.cache_data["data"]) >= test.buffer_size)
113
- cur_path = os.path.dirname(os.path.realpath(__file__))
114
- dump_tensor_data_dir = os.path.join(cur_path, "dump_tensor_data.json")
115
- dump_file_path = os.path.join(cur_path, "dump_file.json")
116
- stack_file_path = os.path.join(cur_path, "stack_file.json")
117
- construct_file_path = os.path.join(cur_path, "construct_file.json")
118
-
119
- test.dump_file_path = dump_file_path
120
- test.dump_tensor_data_dir = dump_tensor_data_dir
121
-
122
- with open(dump_file_path, "w") as f:
123
- dump_data = json.dumps(data)
124
- f.write(dump_data)
125
-
126
- test.flush_data_when_buffer_is_full()
127
-
128
- with open(dump_file_path, "r") as f:
129
- new_data = json.load(f)
130
-
131
- data.update({"data": {"A":"1", "B":"2", "C":"3"}})
132
- self.assertEqual(new_data, data)
133
-
134
- self.assertTrue(test.cache_data["data"] == {})
135
- os.remove(dump_file_path)
136
-
137
-
138
- def test_update_stack(self):
139
- data = {"A":"1", "B":"2", "data":{}}
140
- test = DataWriter()
141
- test.update_stack(data)
142
- self.assertEqual(test.cache_stack, data)
143
-
144
- def test_update_construct(self):
145
- data = {"A":"1", "B":"2", "data":{}}
146
- test = DataWriter()
147
- test.update_construct(data)
148
- self.assertEqual(test.cache_construct, data)
149
-
150
- def test_write_stack_info_json(self):
151
- test = DataWriter()
152
- data = {"A":"1", "B":"2", "data":{}}
153
- test.cache_stack = data
154
-
155
- cur_path = os.path.dirname(os.path.realpath(__file__))
156
- file_path = os.path.join(cur_path, "dump.json")
157
-
158
- test.write_stack_info_json(file_path)
159
-
160
- with open(file_path, "r") as f:
161
- load_result = json.load(f)
162
- try:
163
- self.assertEqual(load_result, data)
164
- finally:
165
- os.remove(file_path)
166
-
167
-
168
- def test_write_construct_info_json(self):
169
- test = DataWriter()
170
- data = {"A":"1", "B":"2", "data":{}}
171
- test.cache_construct = data
172
-
173
- cur_path = os.path.dirname(os.path.realpath(__file__))
174
- file_path = os.path.join(cur_path, "dump.json")
175
-
176
- test.write_construct_info_json(file_path)
177
-
178
- with open(file_path, "r") as f:
179
- load_result = json.load(f)
180
- try:
181
- self.assertEqual(load_result, data)
182
- finally:
183
- os.remove(file_path)