mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -0,0 +1,145 @@
+ import os
+ import re
+ import argparse
+ from glob import glob
+
+ import pandas as pd
+
+ from msprobe.core.common.log import logger
+
+
+ def parse_logfile(logfile):
+     grad_norm = []
+     step = []
+     with open(logfile) as f:
+         for line in f.readlines():
+             if 'consumed samples' in line:
+                 grad_norm.append(float(re.findall('(?<=grad norm\: )[\d\.]*', line)[0]))
+     return grad_norm
+
+
+ def parse_monitor_output(output_dir):
+     reduced = {}
+     unreduced = {}
+     for dir in glob(output_dir + '*'):
+         rank = int(re.findall('(?<=rank)[\d]*', dir)[0])
+         unreduced[rank] = []
+         reduced[rank] = []
+         for file in os.listdir(dir):
+             df = pd.read_csv(os.path.join(dir, file))
+             if '_unreduced_' in file:
+                 unreduced[rank].append(df)
+                 pass
+             elif '_reduced_' in file:
+                 reduced[rank].append(df)
+             else:
+                 logger.info(f'unexpected file {file} in {dir}')
+     return reduced, unreduced
+
+
+ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
+     steps = len(reduced[0])
+     world_size = len(reduced)
+     errors = []
+     for index, row in unreduced[0][0].iterrows():
+         param = row['param_name']
+         is_tp_duplicate = False
+         for step in range(2):
+             # sum reduced
+             reduced_mean = 0.
+             for rank in range(world_size):
+                 if len(reduced[rank]) == 0:
+                     continue
+                 df = reduced[rank][step]
+                 value = list(df[df['param_name'] == param]['mean'])
+                 if not value:
+                     if step == 0:
+                         is_tp_duplicate = True
+                     continue
+                 reduced_mean += value[0]
+
+             # sum unreduced
+             unreduced_mean = 0.
+             for rank in range(world_size):
+                 df = unreduced[rank][step]
+                 value = list(df[df['param_name'] == param]['mean'])
+                 if not value:
+                     continue
+                 unreduced_mean += list(df[df['param_name'] == param]['mean'])[0]
+
+             unreduced_mean /= dp_size
+             if is_tp_duplicate and (not sequence_parallel or 'embedding' in param):
+                 unreduced_mean /= tp_size
+             try:
+                 assert_equal(unreduced_mean, reduced_mean)
+             except AssertionError as e:
+                 errors.append([param, step, e, is_tp_duplicate])
+     if errors:
+         logger.info(errors)
+     else:
+         logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitord.')
+
+
+ def assert_equal(a, b):
+     if b == 0 or a == 0:
+         return
+     if b == 0:
+         rel_diff = a
+     elif a == 0:
+         rel_diff = b
+     else:
+         rel_diff = abs(a / b - 1)
+     assert rel_diff < 0.01, f'{a}, {b}, {rel_diff}'
+
+
+ def valid_total_norm(total_norm, reduced, duplicate_embedding):
+     steps = len(total_norm)
+     world_size = len(reduced)
+     errors = []
+     for step in range(steps):
+         calculated_norm = 0.
+         for rank in range(world_size):
+             if len(reduced[rank]) == 0:
+                 if step == 0:
+                     logger.info(f'rank {rank} is duplicated in dp group')
+                 continue
+             for index, row in reduced[rank][step].iterrows():
+                 if duplicate_embedding and 'word_embedding' in row['param_name']:
+                     continue
+                 calculated_norm += row['norm'] ** 2
+         try:
+             assert_equal(calculated_norm ** 0.5, total_norm[step])
+         except AssertionError as e:
+             errors.append([step, e])
+     if errors:
+         logger.info('total norm errors: ', errors)
+     else:
+         logger.info('grad norm in consist between training log and reduced gradients monitored')
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--monitor_output', '-m', type=str, required=True,
+                         help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16')
+     parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file')
+     parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size')
+     parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size')
+     parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size')
+     parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False,
+                         help='whether untie_embeddings_and_output_weights in pp parallel')
+     parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False,
+                         help='whether sequence parallel is enabled. Add -s to store true')
+
+     args = parser.parse_args()
+
+     assert args.tp_size > 0, 'if tp not enabled, set tp_size = 1'
+     assert args.dp_size > 0, 'if tp not enabled, set dp_size = 1'
+     assert args.pp_size > 0, 'if tp not enabled, set pp_size = 1'
+
+     total_norm = parse_logfile(args.logfile)
+     reduced, unreduced = parse_monitor_output(args.monitor_output)
+
+     duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1
+
+     valid_total_norm(total_norm, reduced, duplicate_embedding)
+     valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel)
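
Judging by its size (+145 new lines), the hunk above appears to correspond to msprobe/pytorch/monitor/unittest/test_monitor.py from the file list: it cross-checks the monitor's per-rank gradient CSVs against the grad norms printed in the training log. Below is a minimal, self-contained sketch (not part of the package) of the same consistency check that valid_total_norm performs; the parameter names and values are hypothetical and only pandas is required.

import pandas as pd

# Hypothetical monitor output for two ranks and a single step; the column names
# mirror the per-parameter CSV schema read by parse_monitor_output above.
reduced = {
    0: [pd.DataFrame({'param_name': ['layer.weight'], 'norm': [3.0]})],
    1: [pd.DataFrame({'param_name': ['layer.bias'], 'norm': [4.0]})],
}
logged_total_norm = 5.0  # the value parse_logfile would extract from the training log

# Combine the per-rank norms as sqrt(sum(norm ** 2)) and compare against the
# logged value with the same 1% relative tolerance used by assert_equal.
calculated = sum(
    row['norm'] ** 2
    for dfs in reduced.values()
    for _, row in dfs[0].iterrows()
) ** 0.5
rel_diff = abs(calculated / logged_total_norm - 1)
assert rel_diff < 0.01, f'{calculated}, {logged_total_norm}, {rel_diff}'
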
@@ -0,0 +1,250 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import inspect
+ from collections import namedtuple
+ from datetime import timezone, timedelta
+ from functools import wraps
+
+ import torch
+
+ from msprobe.core.common.const import MonitorConst, Const
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.utils import is_int
+
+ FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
+ FILE_NAME_MAX_LENGTH = 255
+ DIRECTORY_MAX_LENGTH = 4096
+
+ beijing_tz = timezone(timedelta(hours=8))
+ MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
+ MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad"))
+
+
+ class MsgConst:
+     """
+     Class for log messages const
+     """
+     SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
+
+
+ def filter_special_chars(func):
+     @wraps(func)
+     def func_level(msg):
+         for char in MsgConst.SPECIAL_CHAR:
+             msg = msg.replace(char, '_')
+         return func(msg)
+
+     return func_level
+
+
+ def get_param_struct(param):
+     res = {}
+     if isinstance(param, (tuple, list)):
+         res['config'] = f'{type(param).__name__}[{len(param)}]'
+         for i, x in enumerate(param):
+             res[i] = f'size={tuple(x.shape)}, dtype={x.dtype}' if torch.is_tensor(x) else f'{type(x)}'
+     elif torch.is_tensor(param):
+         res['config'] = 'tensor'
+         res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
+     else:
+         res['config'] = f'{type(param)}'
+         logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}')
+     return res
+
+
+ def is_recomputation():
+     """Check if the current operation is in the re-computation phase.
+
+     This function inspects the current call stack to indicate whether the current operation is in the
+     re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
+     megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
+     mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
+     file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the
+     'torch/_tensor.py' file.
+
+     Returns:
+         bool: True if in the re-computation phase, False otherwise.
+     """
+     backward_function_indices = []
+     call_stack = inspect.stack()
+
+     # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
+     for frame_info in call_stack:
+         if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'):
+             del call_stack
+             return True
+
+     # Identify indices in the call stack where the specific function is being executed
+     for idx, frame_info in enumerate(call_stack):
+         if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+             backward_function_indices.append(idx)
+
+     # Check if the execution is within 'torch/autograd/function.py' file
+     for idx in backward_function_indices:
+         # The Megatron and MindSpeed L0&L1 scenes
+         if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
+             del call_stack
+             return True
+         # The latest MindSpeed L2 and ModelLink scenes
+         if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
+             del call_stack
+             return True
+
+     del call_stack
+     return False
+
+
+ def validate_ops(ops):
+     if not isinstance(ops, list):
+         raise TypeError("ops should be a list")
+     if not ops:
+         raise TypeError(f"specify ops to calculate metrics. Optional ops: {MonitorConst.OP_LIST}")
+
+     valid_ops = []
+     for op in ops:
+         if op not in MonitorConst.OP_LIST:
+             logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
+         else:
+             valid_ops.append(op)
+     return valid_ops
+
+
+ def validate_ranks(ranks):
+     if not isinstance(ranks, list):
+         raise TypeError("module_ranks should be a list")
+     for rank in ranks:
+         if not isinstance(rank, int) or isinstance(rank, bool):
+             raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
+
+
+ def validate_targets(targets):
+     if not isinstance(targets, dict):
+         raise TypeError('targets in config.json should be a dict')
+     for module_name, field in targets.items():
+         if not isinstance(module_name, str):
+             raise TypeError('key of targets should be module_name[str] in config.json')
+         if not isinstance(field, dict):
+             raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
+
+
+ def validate_print_struct(print_struct):
+     if not isinstance(print_struct, bool):
+         raise TypeError("print_struct should be a bool")
+
+
+ def validate_ur_distribution(ur_distribution):
+     if not isinstance(ur_distribution, bool):
+         raise TypeError('ur_distribution should be a bool')
+
+
+ def validate_xy_distribution(xy_distribution):
+     if not isinstance(xy_distribution, bool):
+         raise TypeError('xy_distribution should be a bool')
+
+
+ def validate_wg_distribution(wg_distribution):
+     if not isinstance(wg_distribution, bool):
+         raise TypeError('wg_distribution should be a bool')
+
+
+ def validate_mg_distribution(mg_distribution):
+     if not isinstance(mg_distribution, bool):
+         raise TypeError('mg_distribution should be a bool')
+
+
+ def validate_cc_distribution(cc_distribution):
+     if not isinstance(cc_distribution, dict):
+         raise TypeError('cc_distribution should be a dictionary')
+     for key, value in cc_distribution.items():
+         if key == 'enable':
+             if not isinstance(value, bool):
+                 raise TypeError('cc_distribution enable should be a bool')
+         elif key == 'cc_codeline':
+             if not isinstance(value, list):
+                 raise TypeError('cc_distribution cc_codeline should be a list')
+         elif key == 'cc_pre_hook':
+             if not isinstance(value, bool):
+                 raise TypeError('cc_distribution cc_pre_hook should be a bool')
+         elif key == 'cc_log_only':
+             if not isinstance(value, bool):
+                 raise TypeError('cc_distribution cc_log_only should be a bool')
+         else:
+             raise TypeError(f'{key} of cc_distribution is not supported.')
+
+
+ def validate_alert(alert):
+     if not isinstance(alert, dict):
+         raise TypeError('alert should be a dictionary')
+     rules = alert.get('rules')
+     if rules and isinstance(rules, list):
+         for rule in rules:
+             rule_name = rule.get("rule_name")
+             if rule_name and rule_name not in MonitorConst.RULE_NAME:
+                 raise TypeError(f"{rule_name} is not supported")
+             args = rule.get("args")
+             if args and isinstance(args, dict):
+                 threshold = args.get("threshold")
+                 if not isinstance(threshold, float) or threshold < 0:
+                     raise TypeError('threshold must be float and not less than 0')
+     dump = alert.get('dump')
+     if dump and not isinstance(dump, bool):
+         raise TypeError('dump must be bool.')
+
+
+ def validate_step_count_per_record(step_count_per_record):
+     if not is_int(step_count_per_record):
+         raise TypeError('step_count_per_record must be int.')
+     if step_count_per_record < 1:
+         raise ValueError("step_count_per_record must greater than 0")
+     if step_count_per_record > 1e6:
+         raise ValueError("step_count_per_record must smaller than 1e6")
+
+
+ def validate_config(config):
+     config['ops'] = validate_ops(config.get('ops', []))
+
+     eps = config.get('eps', 1e-8)
+     if not isinstance(eps, float):
+         raise TypeError("eps should be a float")
+
+     ranks = config.get("module_ranks", [])
+     validate_ranks(ranks)
+
+     targets = config.get("targets", {})
+     validate_targets(targets)
+
+     print_struct = config.get('print_struct', False)
+     validate_print_struct(print_struct)
+
+     ur_distribution = config.get('ur_distribution', False)
+     validate_ur_distribution(ur_distribution)
+
+     xy_distribution = config.get('xy_distribution', False)
+     validate_xy_distribution(xy_distribution)
+
+     wg_distribution = config.get('wg_distribution', False)
+     validate_wg_distribution(wg_distribution)
+
+     mg_distribution = config.get('mg_distribution', False)
+     validate_mg_distribution(mg_distribution)
+
+     cc_distribution = config.get('cc_distribution', {})
+     validate_cc_distribution(cc_distribution)
+
+     alert = config.get('alert', {})
+     validate_alert(alert)
+
+     step_count_per_record = config.get('step_count_per_record', 1)
+     validate_step_count_per_record(step_count_per_record)
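
By line count (+250 -0) this hunk matches msprobe/pytorch/monitor/utils.py in the file list; validate_config walks a monitor config dict and type-checks every field. As a hedged illustration (not taken from the package docs), a config along the following lines should pass these validators, assuming the module is importable as msprobe.pytorch.monitor.utils and that the chosen op names exist in MonitorConst.OP_LIST (unsupported names are only warned about and filtered out).

from msprobe.pytorch.monitor.utils import validate_config

example_config = {
    "ops": ["norm", "min", "max"],   # filtered against MonitorConst.OP_LIST; assumed valid names
    "eps": 1e-8,                     # must be a float
    "module_ranks": [0, 1],          # ints only; bools are rejected
    "targets": {"module.layers.0": {"input": "tensor"}},  # hypothetical module name
    "print_struct": False,
    "ur_distribution": False,
    "xy_distribution": True,
    "wg_distribution": False,
    "mg_distribution": False,
    "cc_distribution": {"enable": False, "cc_codeline": []},
    "alert": {"dump": True},
    "step_count_per_record": 1,
}

validate_config(example_config)  # raises TypeError/ValueError on the first invalid field
# note: validate_config rewrites example_config['ops'] in place with the validated subset
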
@@ -0,0 +1,59 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from msprobe.pytorch.monitor.features import cal_histc
+
+
+ class HeatmapVisualizer:
+     def __init__(self) -> None:
+         self.histogram_bins_num = 30
+         self.min_val = -1
+         self.max_val = 1
+         self.histogram_edges = None
+         self.histogram_sum_data_np = None  # matrix shape is [bins_num * total_step]
+         self.cur_step_histogram_data = None
+         self.histogram_edges = torch.linspace(self.min_val, self.max_val, self.histogram_bins_num)
+
+     def pre_cal(self, tensor):
+         self.cur_step_histogram_data = cal_histc(tensor_cal=tensor, bins_total=self.histogram_bins_num,
+                                                  min_val=self.min_val, max_val=self.max_val)
+
+     def visualize(self, tag_name: str, step, summary_writer):
+         if self.histogram_sum_data_np is None or self.histogram_sum_data_np.size == 0:
+             self.histogram_sum_data_np = np.expand_dims(self.cur_step_histogram_data.cpu(), 0).T
+         else:
+             # add new data along a different axis because we transposed early
+             # matrix shape is [bins_num * total_step]
+             self.histogram_sum_data_np = np.concatenate((self.histogram_sum_data_np, np.expand_dims(
+                 self.cur_step_histogram_data.cpu(), 1)), axis=1)
+
+         fig, ax = plt.subplots()
+         cax = ax.matshow(self.histogram_sum_data_np, cmap='hot', aspect='auto')
+         fig.colorbar(cax)
+
+         lbs = [f'{self.histogram_edges[i]:.2f}' for i in range(self.histogram_bins_num)]
+         plt.yticks(ticks=range(self.histogram_bins_num), labels=lbs)
+         ax.set_xlabel('Step')
+         ax.set_ylabel('Value Range')
+         plt.title(f'Total Step: {step}')
+
+         # Convert matplotlib figure to an image format suitable for TensorBoard
+         fig.canvas.draw()
+         image = torch.from_numpy(np.array(fig.canvas.renderer.buffer_rgba()))
+         plt.close(fig)
+         summary_writer.add_image(tag_name, image.permute(2, 0, 1), global_step=step, dataformats='CHW')
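
The hunk above lines up with msprobe/pytorch/monitor/visualizer.py (+59 -0): HeatmapVisualizer accumulates a 30-bin histogram of a tensor per step and renders the growing [bins x steps] matrix as a heatmap image for TensorBoard. A hedged usage sketch (not from the package docs), assuming the class is importable from that path and TensorBoard's SummaryWriter is available:

import torch
from torch.utils.tensorboard import SummaryWriter
from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

writer = SummaryWriter(log_dir='./runs/heatmap_demo')  # hypothetical log dir
viz = HeatmapVisualizer()

for step in range(3):
    monitored = torch.randn(1024) * 0.1   # stand-in for a monitored tensor (e.g. a gradient)
    viz.pre_cal(monitored)                # bucket the values into 30 bins over [-1, 1]
    viz.visualize('demo/heatmap', step, writer)  # appends one column and logs the image

writer.close()
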
@@ -12,9 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ __all__ = ["PtdbgDispatch"]
+
  from signal import signal, SIGPIPE, SIG_DFL
  from .dispatch import PtdbgDispatch
  signal(SIGPIPE, SIG_DFL)
-
-
- __all__ = ["PtdbgDispatch"]
@@ -1,16 +1,30 @@
- # 进行比对及结果展示
- import os
- import sys
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import csv
  import json
+ import os
+ import sys
  from collections import namedtuple
- from rich.table import Table
- from rich.console import Console
+
  from msprobe.core.common.const import CompareConst, FileCheckConst
- from msprobe.core.common.file_utils import FileOpen, change_mode, read_csv
+ from msprobe.core.common.file_utils import read_csv, get_json_contents, write_csv
+ from msprobe.core.common.utils import check_op_str_pattern_valid
  from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap
- from msprobe.pytorch.common.log import logger
- from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
+ from rich.console import Console
+ from rich.table import Table
 
  ELEMENT_NUM_THRESHOLD = 100
  ZERO_NUM_THRESHOLD = 0.1
@@ -19,30 +33,6 @@ FLOAT_PRECISION = 14
  ResultInfo = namedtuple('ResultInfo', ['api_name', 'is_fwd_success', 'is_bwd_success',
                                         'fwd_compare_alg_results', 'bwd_compare_alg_results'])
 
- def get_file_content_bytes(file):
-     with FileOpen(file, 'rb') as file_handle:
-         return file_handle.read()
-
-
- def get_json_contents(file_path):
-     ops = get_file_content_bytes(file_path)
-     try:
-         json_obj = json.loads(ops)
-     except ValueError as error:
-         logger.error('Failed to load "%s". %s' % (file_path, str(error)))
-         raise CompareException(CompareException.INVALID_FILE_ERROR) from error
-     if not isinstance(json_obj, dict):
-         logger.error('Json file %s, content is not a dictionary!' % file_path)
-         raise CompareException(CompareException.INVALID_FILE_ERROR)
-     return json_obj
-
-
- def write_csv(data, filepath):
-     with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
-         writer = csv.writer(f)
-         writer.writerows(data)
-     change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
-
 
  class Saver:
      # consts for result csv
@@ -62,14 +52,15 @@ class Saver:
      }
 
      def write_csv_title(self):
-         summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]]
+         summary_test_rows = [
+             [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]]
          write_csv(summary_test_rows, self.save_path)
 
          detail_test_rows = [[
              "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
              "error_balance", "max_abs_diff", "max_abs_idx",
              "max_rel_diff", "max_rel_idx", "eb_thd",
-             "error_thd", "Status","Message"
+             "error_thd", "Status", "Message"
          ]]
          write_csv(detail_test_rows, self.detail_save_path)
 
@@ -106,7 +97,7 @@
          console.print(table_detail)
 
      def get_statistics_from_result_csv(self):
-         checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
+         checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.N_A, CompareConst.SKIP]
          data = read_csv(self.save_path)
          result_csv_name = os.path.basename(self.save_path)
          for _, row in data.iterrows():
@@ -121,7 +112,7 @@
              if column1 == CompareConst.SKIP:
                  continue
              self.test_result_cnt["total_num"] += 1
-             if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
+             if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, CompareConst.N_A]:
                  self.test_result_cnt['success_num'] += 1
              elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                  self.test_result_cnt['forward_and_backward_fail_num'] += 1
@@ -228,8 +219,8 @@ class Comparator:
          is_bwd_success, bwd_compare_alg_results = True, None
          if is_bwd_success and bwd_compare_alg_results is None:
              self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NAN, fwd_compare_alg_results,
-                                          bwd_compare_alg_results))
+                                                  bwd_compare_alg_results))
          else:
              self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
-                                          bwd_compare_alg_results))
+                                                  bwd_compare_alg_results))
          return is_fwd_success, is_bwd_success
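
The compare.py hunks above are mostly a cleanup: the file-local get_json_contents/write_csv helpers are dropped in favour of the shared ones in msprobe.core.common.file_utils, imports are regrouped, and the literal 'N/A' / CompareConst.NA checks are unified on CompareConst.N_A. A hedged sketch (not from the package docs) of the shared helpers the refactor switches to, assuming they behave like the removed local versions; the paths and rows are hypothetical:

from msprobe.core.common.file_utils import get_json_contents, write_csv

api_info = get_json_contents('dump/api_info.json')  # hypothetical path; expects a JSON dict
write_csv([['API', 'Forward Test Success', 'Backward Test Success', 'Message']],
          'compare_result.csv')                     # hypothetical path; rows are appended
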