mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,24 @@
1
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from msprobe.core.common.const import Const
17
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
3
18
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
4
19
  from msprobe.mindspore.common.log import logger
5
20
 
21
+
6
22
  def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
7
23
  '''
8
24
  Args:
@@ -22,30 +38,30 @@ def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_t
22
38
  3. value is not accepted type
23
39
  4. value is not accepted value
24
40
  '''
25
- parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
26
41
  if not isinstance(dict_instance, dict):
27
- logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", parse_failed_exception)
42
+ error_info = "check_and_get_from_json_dict failed: input is not a dict"
43
+ raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
28
44
  value = dict_instance.get(key)
29
45
  if value is None:
30
- logger.error_log_with_exp(f"check_and_get_from_json_dict failed: {key_description} is missing",
31
- parse_failed_exception)
46
+ error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
47
+ raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
32
48
  elif accepted_type is not None and not isinstance(value, accepted_type):
33
- logger.error_log_with_exp(
34
- f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}",
35
- parse_failed_exception)
49
+ error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
50
+ raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
36
51
  elif accepted_value is not None and value not in accepted_value:
37
- logger.error_log_with_exp(
38
- f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}",
39
- parse_failed_exception)
52
+ error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
53
+ raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
40
54
  return value
41
55
 
42
- def convert_to_tuple(input):
43
- if isinstance(input, (tuple, list)):
44
- return tuple(input)
56
+
57
+ def convert_to_tuple(args):
58
+ if isinstance(args, (tuple, list)):
59
+ return tuple(args)
45
60
  else:
46
- input_list = [input]
61
+ input_list = [args]
47
62
  return tuple(input_list)
48
63
 
64
+
49
65
  def trim_output_compute_element_list(compute_element_list, forward_or_backward):
50
66
  '''
51
67
  Args:
@@ -55,12 +71,13 @@ def trim_output_compute_element_list(compute_element_list, forward_or_backward):
55
71
  trimmed_list = []
56
72
  for compute_element in compute_element_list:
57
73
  if compute_element.get_parameter() is None or \
58
- (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
74
+ (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
59
75
  # trim case: 1. parameter is None. 2. backward output has non float parameter
60
76
  continue
61
77
  trimmed_list.append(compute_element)
62
78
  return trimmed_list
63
79
 
80
+
64
81
  class GlobalContext:
65
82
  def __init__(self):
66
83
  self.is_constructed = True
@@ -77,4 +94,4 @@ class GlobalContext:
77
94
  return self.is_constructed
78
95
 
79
96
 
80
- global_context = GlobalContext()
97
+ global_context = GlobalContext()
@@ -1,17 +1,31 @@
1
- from msprobe.core.data_dump.scope import ModuleRangeScope
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
2
17
  from msprobe.core.common.const import Const
3
- from msprobe.mindspore.common.log import logger
4
18
 
5
19
 
6
20
  class CellProcessor:
7
21
  cell_count = {}
22
+ cell_stack = []
23
+ api_parent_node = ""
24
+ module_node = {}
8
25
 
9
26
  def __init__(self, scope):
10
- if isinstance(scope, ModuleRangeScope):
11
- self.scope = scope
12
- else:
13
- self.scope = None
14
-
27
+ self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
28
+
15
29
  @staticmethod
16
30
  def set_cell_count(cell_name):
17
31
  if cell_name not in CellProcessor.cell_count:
@@ -19,16 +33,47 @@ class CellProcessor:
19
33
  else:
20
34
  CellProcessor.cell_count[cell_name] += 1
21
35
  return CellProcessor.cell_count[cell_name]
22
-
36
+
37
+ @classmethod
38
+ def reset_cell_stats(cls):
39
+ cls.cell_count = {}
40
+ cls.cell_stack = []
41
+ cls.api_parent_node = ""
42
+ cls.module_node = {}
43
+
23
44
  def node_hook(self, name_prefix, start_or_stop, **kwargs):
24
- def begin_hook(cell, input):
25
- index = self.set_cell_count(name_prefix)
26
- cell.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
45
+ def begin_hook(cell, input_data):
46
+ full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
47
+ if CellProcessor.cell_stack:
48
+ CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
49
+ else:
50
+ CellProcessor.module_node[full_name] = None
51
+
52
+ CellProcessor.cell_stack.append(full_name)
53
+ CellProcessor.api_parent_node = full_name
54
+
27
55
  if self.scope:
28
56
  self.scope.begin_module(full_name)
29
-
30
- def end_hook(cell, input, output):
57
+
58
+ def end_hook(cell, input_data, output_data):
59
+ if CellProcessor.cell_stack:
60
+ CellProcessor.cell_stack.pop()
61
+ if CellProcessor.cell_stack:
62
+ CellProcessor.api_parent_node = CellProcessor.cell_stack[-1]
63
+ else:
64
+ CellProcessor.api_parent_node = None
65
+
31
66
  if self.scope:
32
67
  self.scope.end_module(cell.mindstudio_reserved_name)
33
68
 
34
69
  return begin_hook if Const.START == start_or_stop else end_hook
70
+
71
+ def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
72
+ if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
73
+ cell.has_pre_hook_called = False
74
+ else:
75
+ if is_called_by_pre_hook:
76
+ cell.has_pre_hook_called = True
77
+ index = self.set_cell_count(cell_name)
78
+ cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
79
+ return cell.mindstudio_reserved_name
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import numpy as np
2
17
  import mindspore as ms
3
18
 
@@ -23,31 +38,35 @@ class Const:
23
38
  ASCEND_910A = "ascend910"
24
39
 
25
40
  OPS_PREFIX = "mindspore.ops."
26
- Tensor_PREFIX = "mindspore.Tensor."
41
+ TENSOR_PREFIX = "mindspore.Tensor."
27
42
  MINT_PREFIX = "mindspore.mint."
28
43
  MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
29
- COMM_PREFIX = "mindspore.communication.comm_func."
30
- COMMUNICATION_API_LIST = [
31
- "mindspore.communication.comm_func.all_gather_into_tensor",
32
- "mindspore.communication.comm_func.gather_into_tensor",
33
- "mindspore.communication.comm_func.all_reduce",
34
- "mindspore.communication.comm_func.reduce",
35
- "mindspore.communication.comm_func.reduce_scatter_tensor"
36
- ]
44
+
37
45
  TENSOR_DATA_PREFIX = "Tensor."
38
46
  STUB_TENSOR_DATA_PREFIX = "Tensor."
39
47
  OPS_DATA_PREFIX = "Functional."
40
48
  MINT_DATA_PREFIX = "Mint."
41
49
  MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."
50
+ DISTRIBUTED_DATA_PREFIX = "Distributed."
42
51
 
43
52
  SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
44
53
  SUPPORTED_TENSOR_LIST_KEY = "tensor"
45
54
  SUPPORTED_OPS_LIST_KEY = "ops"
46
55
  SUPPORTED_MINT_LIST_KEY = "mint.ops"
47
56
  SUPPORTED__MINT_NN_FUNC_LIST_KEY = "mint.nn.functional"
57
+ SUPPORTED_COMM_LIST_KEY = "communication.comm_func"
48
58
 
49
59
  DROPOUT_API_NAME_PREFIX = "dropout"
50
60
 
61
+ GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
62
+
63
+ HOOK_MS_PREFIX_DICT = {
64
+ OPS_DATA_PREFIX: OPS_PREFIX,
65
+ TENSOR_DATA_PREFIX: TENSOR_PREFIX,
66
+ MINT_DATA_PREFIX: MINT_PREFIX,
67
+ MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
68
+ }
69
+
51
70
 
52
71
  class FreeBenchmarkConst:
53
72
  ADD_NOISE = "add_noise"
@@ -63,19 +82,21 @@ class FreeBenchmarkConst:
63
82
  DEFAULT_PERT_TYPE = IMPROVE_PRECISION
64
83
  DEFAULT_HANDLER_TYPE = CHECK
65
84
  DEVICE_LIST = [DEFAULT_DEVICE]
66
- STAGE_LIST = [CoreConst.FORWARD]
85
+ STAGE_LIST = [CoreConst.FORWARD, CoreConst.BACKWARD]
67
86
  DUMP_LEVEL_LIST = [DEFAULT_DUMP_LEVEL]
68
87
  PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE, EXCHANGE_VALUE]
69
88
  HANDLER_TYPE_LIST = [CHECK, FIX]
70
89
  NO_CHANGE_ERROR_THRESHOLD = 1.0
71
90
  SYMBOL_FLIPPING_RATIO = 8.0
72
91
 
92
+ SUPPORTED_CHECK_API_FILE = "support_wrap_ops.yaml"
93
+ CHECK_RESULT_FILE = "free_benchmark.csv"
94
+
73
95
  API_PREFIX_DICT = {
74
96
  "ops": Const.OPS_PREFIX,
75
- "Tensor": Const.Tensor_PREFIX,
97
+ "Tensor": Const.TENSOR_PREFIX,
76
98
  "mint": Const.MINT_PREFIX,
77
- "mint.nn.functional": Const.MINT_NN_FUNC_PREFIX,
78
- "communication": Const.COMM_PREFIX
99
+ "mint.nn.functional": Const.MINT_NN_FUNC_PREFIX
79
100
  }
80
101
 
81
102
  PERT_VALUE_DICT = {
@@ -86,6 +107,7 @@ class FreeBenchmarkConst:
86
107
  }
87
108
 
88
109
  ERROR_THRESHOLD = {
110
+ ms.bfloat16: 1.004,
89
111
  ms.float16: 1.002,
90
112
  ms.float32: 1.0002
91
113
  }
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,15 +12,10 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
- import os
17
- import time
18
- import sys
19
-
20
- from msprobe.mindspore.common.utils import get_rank_if_initialized
21
- from msprobe.core.common.log import BaseLogger
22
16
  from msprobe.core.common.exceptions import DistributedNotInitializedError
17
+ from msprobe.core.common.log import BaseLogger
18
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
23
19
 
24
20
 
25
21
  class MindsporeLogger(BaseLogger):
@@ -35,4 +31,4 @@ class MindsporeLogger(BaseLogger):
35
31
  return current_rank
36
32
 
37
33
 
38
- logger = MindsporeLogger()
34
+ logger = MindsporeLogger()
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,13 +12,19 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
+
15
16
  import os
17
+ import random
18
+
16
19
  import mindspore as ms
17
20
 
21
+ from mindspore import ops
22
+ from mindspore.mint import nn
18
23
  from msprobe.core.common.exceptions import DistributedNotInitializedError
19
24
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
20
25
  from msprobe.core.common.log import logger
26
+ from msprobe.core.common.const import Const
27
+ from msprobe.core.common.utils import CompareException, check_seed_all
21
28
 
22
29
 
23
30
  def get_rank_if_initialized():
@@ -36,7 +43,7 @@ def convert_bf16_to_fp32(tensor):
36
43
  def save_tensor_as_npy(tensor, file_path):
37
44
  if not path_len_exceeds_limit(file_path):
38
45
  tensor = convert_bf16_to_fp32(tensor)
39
- saved_tensor = tensor.asnumpy()
46
+ saved_tensor = tensor.contiguous().asnumpy()
40
47
  save_npy(saved_tensor, file_path)
41
48
  else:
42
49
  logger.warning(f'The file path {file_path} length exceeds limit.')
@@ -53,12 +60,15 @@ def list_lowest_level_directories(root_dir):
53
60
  check_path_exists(root_dir)
54
61
  lowest_level_dirs = []
55
62
 
56
- def recurse_dirs(current_dir):
63
+ def recurse_dirs(current_dir, depth=0):
64
+ if depth > Const.MAX_DEPTH:
65
+ logger.error(f'The directory {current_dir} has more than {Const.MAX_DEPTH} levels.')
66
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
57
67
  for entry in os.listdir(current_dir):
58
68
  full_path = os.path.join(current_dir, entry)
59
69
  if os.path.isdir(full_path):
60
70
  if any(os.path.isdir(os.path.join(full_path, subentry)) for subentry in os.listdir(full_path)):
61
- recurse_dirs(full_path)
71
+ recurse_dirs(full_path, depth=depth+1)
62
72
  else:
63
73
  lowest_level_dirs.append(full_path)
64
74
 
@@ -66,6 +76,16 @@ def list_lowest_level_directories(root_dir):
66
76
  return lowest_level_dirs
67
77
 
68
78
 
79
def seed_all(seed=1234, mode=False, rm_dropout=True):
    """Seed the random number generators and set the deterministic switches.

    Args:
        seed: Passed to ``check_seed_all`` and then applied to
            ``PYTHONHASHSEED``, ``ms.set_seed`` and ``random.seed``.
        mode: Truthy sets the MindSpore ``deterministic`` context to "ON",
            otherwise "OFF"; ``str(mode)`` is also exported as
            ``HCCL_DETERMINISTIC``.
        rm_dropout: When truthy, ``remove_dropout()`` is called at the end.
    """
    # check_seed_all runs first, before any global state is modified
    # (presumably it rejects invalid arguments; verify against its definition).
    check_seed_all(seed, mode)

    os.environ['PYTHONHASHSEED'] = str(seed)
    ms.set_seed(seed)
    random.seed(seed)

    deterministic_flag = "ON" if mode else "OFF"
    ms.set_context(deterministic=deterministic_flag)
    os.environ['HCCL_DETERMINISTIC'] = str(mode)

    if not rm_dropout:
        return
    remove_dropout()
88
+
69
89
 
70
90
  class MsprobeStep(ms.train.Callback):
71
91
 
@@ -79,3 +99,38 @@ class MsprobeStep(ms.train.Callback):
79
99
  def on_train_step_end(self, run_context):
80
100
  self.debugger.stop()
81
101
  self.debugger.step()
102
+
103
+
104
class Dropout(ops.Dropout):
    """Variant of ``ops.Dropout`` that is always built with keep probability 1.

    ``keep_prob`` is accepted but not used; ``Seed0`` and ``Seed1`` are
    forwarded to the parent constructor unchanged.
    """

    def __init__(self, keep_prob=0.5, Seed0=0, Seed1=1):
        always_keep = 1.
        super().__init__(always_keep, Seed0, Seed1)
107
+
108
+
109
class Dropout2D(ops.Dropout2D):
    """Variant of ``ops.Dropout2D`` that is always built with keep probability 1.

    ``keep_prob`` is accepted but not used.
    """

    def __init__(self, keep_prob=0.5):
        always_keep = 1.
        super().__init__(always_keep)
112
+
113
+
114
class Dropout3D(ops.Dropout3D):
    """Variant of ``ops.Dropout3D`` that is always built with keep probability 1.

    ``keep_prob`` is accepted but not used.
    """

    def __init__(self, keep_prob=0.5):
        always_keep = 1.
        super().__init__(always_keep)
117
+
118
+
119
class DropoutExt(nn.Dropout):
    """Variant of ``nn.Dropout`` that is always built with a drop probability of 0.

    ``p`` is accepted but not used.
    """

    def __init__(self, p=0.5):
        never_drop = 0
        super().__init__(never_drop)
122
+
123
+
124
def dropout_ext(input_tensor, p=0.5, training=True):
    """Return ``input_tensor`` unchanged.

    ``p`` and ``training`` are accepted but never read, so no element is
    dropped or rescaled.
    """
    passthrough = input_tensor
    return passthrough
126
+
127
+
128
def remove_dropout():
    """Rebind the dropout entry points to the replacements defined in this module.

    ``Dropout``, ``Dropout2D`` and ``Dropout3D`` are assigned on both ``ops``
    and ``ops.operations``; ``nn.Dropout`` becomes ``DropoutExt`` and
    ``nn.functional.dropout`` becomes ``dropout_ext``.
    """
    primitive_overrides = (
        ("Dropout", Dropout),
        ("Dropout2D", Dropout2D),
        ("Dropout3D", Dropout3D),
    )
    for attr_name, replacement in primitive_overrides:
        # Patch both namespaces so either lookup path gets the replacement.
        setattr(ops, attr_name, replacement)
        setattr(ops.operations, attr_name, replacement)

    nn.Dropout = DropoutExt
    nn.functional.dropout = dropout_ext
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,31 +12,29 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import os
18
- from msprobe.core.common.utils import CompareException, check_compare_param, \
19
- check_configuration_param, task_dumppath_get
17
+ from msprobe.core.common.utils import CompareException
20
18
  from msprobe.core.common.file_utils import create_directory
21
19
  from msprobe.core.common.exceptions import FileCheckException
22
20
  from msprobe.mindspore.common.log import logger
23
- from msprobe.mindspore.compare.ms_compare import MSComparator
21
+ from msprobe.mindspore.compare.ms_compare import ms_compare
24
22
  from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
23
  from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
26
24
 
25
+
27
26
  def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
28
27
  if kwargs.get('suffix'):
29
28
  logger.error("Argument 'suffix' is not supported for compare_distributed.")
30
29
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
31
- stack_mode = kwargs.get('stack_mode', False)
32
- auto_analyze = kwargs.get('auto_analyze', True)
33
- fuzzy_match = kwargs.get('fuzzy_match', False)
30
+ is_print_compare_log = kwargs.get('is_print_compare_log', True)
34
31
  # get the ranks and match by order
35
32
  npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
36
33
  bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
37
34
  if len(npu_ranks) != len(bench_ranks):
38
35
  logger.error('The number of ranks in the two runs are different. '
39
- 'Unable to match the ranks. Please use another folder to compare '
40
- 'or use compare() api and manually match the ranks.')
36
+ 'Unable to match the ranks. Please use another folder to compare '
37
+ 'or use compare() api and manually match the ranks.')
41
38
  raise CompareException(CompareException.INVALID_PATH_ERROR)
42
39
  for nr, br in zip(npu_ranks, bench_ranks):
43
40
  npu_data_dir = os.path.join(npu_dump_dir, nr)
@@ -50,19 +47,9 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
50
47
  'npu_json_path': npu_path,
51
48
  'bench_json_path': bench_path,
52
49
  'stack_json_path': stack_path,
53
- 'is_print_compare_log': True
50
+ 'is_print_compare_log': is_print_compare_log
54
51
  }
55
- try:
56
- summary_compare, md5_compare = task_dumppath_get(dump_result_param)
57
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
58
- create_directory(output_path)
59
- check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
60
- except (CompareException, FileCheckException) as error:
61
- logger.error('Compare failed. Please check the arguments and do it again!')
62
- raise CompareException(error.code) from error
63
- ms_comparator = MSComparator()
64
- ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
65
- md5_compare=md5_compare, **kwargs)
52
+ ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
66
53
 
67
54
 
68
55
  def ms_graph_compare(inputs, outputs):
@@ -71,5 +58,5 @@ def ms_graph_compare(inputs, outputs):
71
58
  except (CompareException, FileCheckException) as error:
72
59
  logger.error('Compare failed. Please check the arguments and do it again!')
73
60
  return
74
- msComparator = GraphMSComparator(inputs, outputs)
75
- msComparator.compare_core()
61
+ ms_comparator = GraphMSComparator(inputs, outputs)
62
+ ms_comparator.compare_core()