mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -22,6 +22,7 @@ from msprobe.core.config_check.config_checker import register_checker_item, regi
22
22
  from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
23
23
  from msprobe.core.common.decorator import recursion_depth_decorator
24
24
  from msprobe.core.common.framework_adapter import FmkAdp
25
+ from msprobe.core.common.const import Const
25
26
 
26
27
 
27
28
  @recursion_depth_decorator("config_check: process_obj")
@@ -134,5 +135,5 @@ class DatasetChecker(BaseChecker):
134
135
  cmp_dataset_pack_path = os.path.join(cmp_dir, DatasetChecker.target_name_in_zip)
135
136
 
136
137
  df = compare_dataset(bench_dataset_pack_path, cmp_dataset_pack_path)
137
- pass_check = False not in df['equal'].values
138
+ pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR
138
139
  return DatasetChecker.target_name_in_zip, pass_check, df
@@ -21,7 +21,7 @@ import pandas as pd
21
21
  from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip
22
22
  from msprobe.core.config_check.checkers.base_checker import BaseChecker
23
23
  from msprobe.core.config_check.config_checker import register_checker_item
24
- from msprobe.core.config_check.utils.utils import config_checking_print
24
+ from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check
25
25
  from msprobe.core.common.const import Const
26
26
 
27
27
 
@@ -59,17 +59,17 @@ def compare_env_data(npu_path, bench_path):
59
59
  cmp_env_name = cmp_env["name"]
60
60
  cmp_value = cmp_data.get(cmp_env_name, value[cmp_type]["default_value"])
61
61
  if not bench_env:
62
- data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, "warning"])
62
+ data.append(["only cmp has this env", cmp_env["name"], "", cmp_value, Const.CONFIG_CHECK_WARNING])
63
63
  continue
64
64
  bench_env_name = bench_env["name"]
65
65
  bench_value = bench_data.get(bench_env_name, value[bench_type]["default_value"])
66
66
  if cmp_value != bench_value:
67
- data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, "error"])
67
+ data.append([bench_env_name, cmp_env_name, bench_value, cmp_value, Const.CONFIG_CHECK_ERROR])
68
68
  else:
69
69
  bench_env_name = bench_env["name"]
70
70
  bench_value = bench_data.get(bench_env_name) if bench_data.get(bench_env_name) else value[bench_type][
71
71
  "default_value"]
72
- data.append([bench_env_name, "only bench has this env", bench_value, "", "warning"])
72
+ data.append([bench_env_name, "only bench has this env", bench_value, "", Const.CONFIG_CHECK_WARNING])
73
73
  df = pd.DataFrame(data, columns=EnvArgsChecker.result_header)
74
74
  return df
75
75
 
@@ -92,5 +92,5 @@ class EnvArgsChecker(BaseChecker):
92
92
  bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip)
93
93
  cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip)
94
94
  df = compare_env_data(bench_env_data, cmp_env_data)
95
- pass_check = "error" not in df['level'].values
95
+ pass_check = process_pass_check(df['level'].values)
96
96
  return EnvArgsChecker.target_name_in_zip, pass_check, df
@@ -23,7 +23,7 @@ import pandas as pd
23
23
  from msprobe.core.common.utils import check_extern_input_list
24
24
  from msprobe.core.config_check.checkers.base_checker import BaseChecker
25
25
  from msprobe.core.config_check.config_checker import register_checker_item
26
- from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict
26
+ from msprobe.core.config_check.utils.utils import compare_dict, config_checking_print, update_dict, process_pass_check
27
27
  from msprobe.core.config_check.utils.hyperparameter_parser import ParserFactory
28
28
  from msprobe.core.common.file_utils import (check_file_or_directory_path, create_file_in_zip, load_json,
29
29
  load_yaml)
@@ -36,6 +36,20 @@ parameter_name_mapping = load_yaml(os.path.realpath(hyperparameters_path))
36
36
  hyperparameters_dict = {}
37
37
 
38
38
 
39
+ def refine_json_keys(json_dcit):
40
+ new_dict = {}
41
+ for key in json_dcit.keys():
42
+ new_key = key.split(Const.SEP)[-1].replace("-", "_")
43
+ new_dict[new_key] = key
44
+ return new_dict
45
+
46
+
47
+ def to_str_if_number(value):
48
+ if isinstance(value, (int, float)):
49
+ return str(value)
50
+ return value
51
+
52
+
39
53
  @register_checker_item("hyperparameter")
40
54
  class HyperparameterChecker(BaseChecker):
41
55
  target_name_in_zip = "hyperparameters"
@@ -86,29 +100,35 @@ class HyperparameterChecker(BaseChecker):
86
100
  all_diffs.extend(
87
101
  HyperparameterChecker.compare_param(bench_hyperparameters, cmp_hyperparameters, file_name))
88
102
  df = pd.DataFrame(all_diffs, columns=HyperparameterChecker.result_header)
89
- pass_check = "error" not in df["level"].values
103
+ pass_check = process_pass_check(df["level"].values)
90
104
  return HyperparameterChecker.target_name_in_zip, pass_check, df
91
105
 
92
106
  @staticmethod
93
107
  def compare_param(bench_params, cmp_params, file_name):
94
108
  all_diffs = []
95
- bench_param_names = bench_params.keys()
96
- for bench_param_name in bench_param_names:
109
+ bench_params_refined = refine_json_keys(bench_params)
110
+ cmp_params_refined = refine_json_keys(cmp_params)
111
+
112
+ for bench_param_name in bench_params_refined.keys():
97
113
  matched_cmp_param_name, matched_with = HyperparameterChecker._fuzzy_match_parameter(bench_param_name,
98
- cmp_params)
99
- bench_param_value = bench_params[bench_param_name]
114
+ cmp_params_refined)
115
+ matched_cmp_param_name = cmp_params_refined.get(matched_cmp_param_name)
116
+ bench_param_name = bench_params_refined.get(bench_param_name)
117
+ bench_param_value = to_str_if_number(bench_params[bench_param_name])
100
118
  if matched_cmp_param_name:
101
- cmp_param_value = cmp_params[matched_cmp_param_name]
119
+ cmp_param_value = to_str_if_number(cmp_params[matched_cmp_param_name])
102
120
  if bench_param_value != cmp_param_value:
103
121
  all_diffs.append(
104
122
  [file_name, bench_param_name, matched_cmp_param_name, bench_param_value, cmp_param_value,
105
- matched_with, "error"])
123
+ matched_with, Const.CONFIG_CHECK_ERROR])
106
124
  del cmp_params[matched_cmp_param_name]
107
125
  else:
108
126
  all_diffs.append(
109
- [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "", "warning"])
127
+ [file_name, bench_param_name, "Only in benchmark", bench_param_value, "", "",
128
+ Const.CONFIG_CHECK_WARNING])
110
129
  for cmp_param_name, cmp_param_value in cmp_params.items():
111
- all_diffs.append([file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", "warning"])
130
+ all_diffs.append(
131
+ [file_name, "Only in comparison", cmp_param_name, "", cmp_param_value, "", Const.CONFIG_CHECK_WARNING])
112
132
  all_diffs.sort()
113
133
  return all_diffs
114
134
 
@@ -23,8 +23,9 @@ except ImportError:
23
23
  from msprobe.core.common.file_utils import load_yaml, create_file_in_zip
24
24
  from msprobe.core.config_check.checkers.base_checker import BaseChecker
25
25
  from msprobe.core.config_check.config_checker import register_checker_item
26
- from msprobe.core.config_check.utils.utils import config_checking_print
26
+ from msprobe.core.config_check.utils.utils import config_checking_print, process_pass_check
27
27
  from msprobe.core.common.file_utils import FileOpen, save_excel
28
+ from msprobe.core.common.const import Const
28
29
 
29
30
  dirpath = os.path.dirname(__file__)
30
31
  depend_path = os.path.join(dirpath, "../resource/dependency.yaml")
@@ -62,7 +63,7 @@ def compare_pip_data(bench_pip_path, cmp_pip_path, fmk):
62
63
  if bench_version != cmp_version:
63
64
  data.append([package, bench_version if bench_version else 'None',
64
65
  cmp_version if cmp_version else 'None',
65
- "error"])
66
+ Const.CONFIG_CHECK_ERROR])
66
67
 
67
68
  df = pd.DataFrame(data, columns=PipPackageChecker.result_header)
68
69
  return df
@@ -86,5 +87,5 @@ class PipPackageChecker(BaseChecker):
86
87
  bench_pip_path = os.path.join(bench_dir, PipPackageChecker.target_name_in_zip)
87
88
  cmp_pip_path = os.path.join(cmp_dir, PipPackageChecker.target_name_in_zip)
88
89
  df = compare_pip_data(bench_pip_path, cmp_pip_path, fmk)
89
- pass_check = "error" not in df['level'].values
90
+ pass_check = process_pass_check(df['level'].values)
90
91
  return PipPackageChecker.target_name_in_zip, pass_check, df
@@ -280,9 +280,9 @@ def mindspore_patchs():
280
280
  import mindspore
281
281
 
282
282
  mindspore_ops_patches = {
283
- 'rand': mindspore.ops.uniform,
283
+ 'rand': mindspore.ops.rand,
284
284
  'randint': mindspore.ops.randint,
285
- 'randn': mindspore.ops.normal
285
+ 'randn': mindspore.ops.randn
286
286
  }
287
287
  for name, func in mindspore_ops_patches.items():
288
288
  setattr(mindspore.ops, name, track_random_call(func, f"mindspore.ops.{name}"))
@@ -331,7 +331,7 @@ class RandomChecker(BaseChecker):
331
331
  cmp_stats_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip)
332
332
 
333
333
  df = compare_random_calls(bench_stats_path, cmp_stats_path)
334
- pass_check = False not in df['check_result'].values
334
+ pass_check = Const.CONFIG_CHECK_PASS if False not in df['check_result'].values else Const.CONFIG_CHECK_ERROR
335
335
 
336
336
  return RandomChecker.target_name_in_zip, pass_check, df
337
337
 
@@ -22,6 +22,7 @@ from msprobe.core.config_check.checkers.base_checker import BaseChecker
22
22
  from msprobe.core.config_check.config_checker import register_checker_item, register_pre_forward_fun_list
23
23
  from msprobe.core.config_check.utils.utils import config_checking_print, get_tensor_features
24
24
  from msprobe.core.common.framework_adapter import FmkAdp
25
+ from msprobe.core.common.const import Const
25
26
 
26
27
 
27
28
  def collect_weights_data(model):
@@ -143,5 +144,5 @@ class WeightsChecker(BaseChecker):
143
144
  bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip)
144
145
  cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip)
145
146
  df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path)
146
- pass_check = False not in df['equal'].values
147
+ pass_check = Const.CONFIG_CHECK_PASS if False not in df['equal'].values else Const.CONFIG_CHECK_ERROR
147
148
  return WeightsChecker.target_name_in_zip, pass_check, df
@@ -138,6 +138,8 @@ def _consolidate_tp_weights(weights: Dict) -> Dict:
138
138
  def _parse_num_layers_per_stage(tp_partition):
139
139
  match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()]
140
140
  layer_idx = [int(i[0]) for i in match if i]
141
+ if not layer_idx:
142
+ return 1
141
143
  num_layers_per_pipeline_stage = max(layer_idx) + 1
142
144
 
143
145
  return num_layers_per_pipeline_stage
@@ -18,4 +18,14 @@ weight_decay:
18
18
 
19
19
  dropout_rate:
20
20
  - dropout
21
- - drop_rate
21
+ - drop_rate
22
+
23
+ compute_dtype:
24
+ - bf16
25
+ - fp32
26
+
27
+ residual_dtype:
28
+ - fp32_residual_connection
29
+
30
+ softmax_compute_dtype:
31
+ - attention_softmax_in_fp32
@@ -96,9 +96,13 @@ class YamlParser(Parser):
96
96
  new_prefix = prefix + Const.SEP + key if prefix else key
97
97
  self.recursive_parse_parameters(value, new_prefix)
98
98
  elif isinstance(parameters, list):
99
- for value in parameters:
100
- self.recursive_parse_parameters(value, prefix)
101
- elif isinstance(parameters, (int, str, bool)):
99
+ if all(isinstance(x, (int, float, str, bool, list))for x in parameters):
100
+ self.hyperparameters.update({prefix: parameters})
101
+ else:
102
+ for idx, value in enumerate(parameters):
103
+ new_prefix = prefix + Const.SEP + str(idx) if prefix else str(idx)
104
+ self.recursive_parse_parameters(value, new_prefix)
105
+ elif isinstance(parameters, (int, float, str, bool)):
102
106
  self.hyperparameters.update({prefix: parameters})
103
107
 
104
108
 
@@ -19,6 +19,7 @@ import hashlib
19
19
 
20
20
  from msprobe.core.common.framework_adapter import FmkAdp
21
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.common.const import Const
22
23
 
23
24
 
24
25
  def merge_keys(dir_0, dir_1):
@@ -105,3 +106,12 @@ def update_dict(ori_dict, new_dict):
105
106
  ori_dict[key] = {"description": "duplicate_value", "values": [ori_dict[key], new_dict[key]]}
106
107
  else:
107
108
  ori_dict[key] = value
109
+
110
+
111
+ def process_pass_check(data):
112
+ if Const.CONFIG_CHECK_ERROR in data:
113
+ return Const.CONFIG_CHECK_ERROR
114
+ elif Const.CONFIG_CHECK_WARNING in data:
115
+ return Const.CONFIG_CHECK_WARNING
116
+ else:
117
+ return Const.CONFIG_CHECK_PASS
@@ -35,7 +35,7 @@ class ApiWrapper:
35
35
  def __init__(
36
36
  self, api_types: Dict[str, Dict[str, Any]],
37
37
  api_list_paths: Union[str, List[str], Tuple[str]],
38
- backlist: Union[List[str], Tuple[str]] = None
38
+ blacklist: Union[List[str], Tuple[str]] = None
39
39
  ):
40
40
  self.api_types = api_types
41
41
  if not isinstance(api_list_paths, (list, tuple)):
@@ -44,7 +44,7 @@ class ApiWrapper:
44
44
  raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
45
45
  "when api_list_paths is a list or tuple.")
46
46
  self.api_list_paths = api_list_paths
47
- self.backlist = backlist if backlist else []
47
+ self.blacklist = blacklist if blacklist else []
48
48
  self.api_names = self._get_api_names()
49
49
  self.wrapped_api_functions = dict()
50
50
 
@@ -80,6 +80,26 @@ class ApiWrapper:
80
80
 
81
81
  return True, args, kwargs
82
82
 
83
+ def wrap_api_func(self, api_name, api_func, prefix, hook_build_func, api_template):
84
+ api_instance = api_template(api_name, api_func, prefix, hook_build_func)
85
+
86
+ def api_function(*args, **kwargs):
87
+ api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
88
+ enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, api_func, args, kwargs)
89
+ if not enable_wrap:
90
+ logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
91
+ 'It may be fixed by passing the value of "self" '
92
+ 'as a positional argument instead of a keyword argument. ')
93
+ return api_func(*args, **kwargs)
94
+ return api_instance(*args, **kwargs)
95
+
96
+ for attr_name in Const.API_ATTR_LIST:
97
+ if hasattr(api_func, attr_name):
98
+ attr_value = getattr(api_func, attr_name)
99
+ setattr(api_function, attr_name, attr_value)
100
+
101
+ return api_function
102
+
83
103
  def wrap_api(
84
104
  self, api_templates, hook_build_func: Optional[Callable]
85
105
  ):
@@ -100,23 +120,17 @@ class ApiWrapper:
100
120
  api_template = api_templates[index]
101
121
  index += 1
102
122
  for api_name in self.api_names.get(framework, {}).get(api_type, []):
103
- ori_api = _get_attr(api_modules[0], api_name)
123
+ ori_api = None
124
+ for module in api_modules[0]:
125
+ ori_api = ori_api or _get_attr(module, api_name)
104
126
  if callable(ori_api):
105
- def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
106
- def api_function(*args, **kwargs):
107
- api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
108
- enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
109
- api_func, args, kwargs)
110
- if not enable_wrap:
111
- logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
112
- 'It may be fixed by passing the value of "self" '
113
- 'as a positional argument instead of a keyword argument. ')
114
- return api_func(*args, **kwargs)
115
- return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
116
- api_function.__name__ = api_name
117
- return api_function
118
- wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
119
- hook_build_func, api_template)
127
+ wrapped_functions[api_name] = self.wrap_api_func(
128
+ api_name,
129
+ ori_api,
130
+ name_prefix,
131
+ hook_build_func,
132
+ api_template
133
+ )
120
134
  wrapped_functions_in_framework[api_type] = wrapped_functions
121
135
  self.wrapped_api_functions[framework] = wrapped_functions_in_framework
122
136
  return self.wrapped_api_functions
@@ -132,15 +146,17 @@ class ApiWrapper:
132
146
  api_from_file = api_list.get(key_in_file, [])
133
147
  names = set()
134
148
  for api_name in api_from_file:
135
- if f'{key_in_file}.{api_name}' in self.backlist:
149
+ if f'{key_in_file}.{api_name}' in self.blacklist:
136
150
  continue
137
151
  target_attr = api_name
138
- target_module = api_modules[0]
139
- if Const.SEP in api_name:
140
- sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
141
- target_module = getattr(api_modules[0], sub_module_name, None)
142
- if target_module and target_attr in dir(target_module):
143
- names.add(api_name)
152
+ for module in api_modules[0]:
153
+ if Const.SEP in api_name:
154
+ sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
155
+ target_module = getattr(module, sub_module_name, None)
156
+ else:
157
+ target_module = module
158
+ if target_module and target_attr in dir(target_module):
159
+ names.add(api_name)
144
160
  valid_names[api_type] = names
145
161
  api_names[framework] = valid_names
146
162
 
@@ -152,7 +168,7 @@ class ApiRegistry:
152
168
  Base class for api registry.
153
169
  """
154
170
 
155
- def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
171
+ def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, blacklist=None):
156
172
  self.ori_api_attr = dict()
157
173
  self.wrapped_api_attr = dict()
158
174
  self.inner_used_ori_attr = dict()
@@ -161,13 +177,16 @@ class ApiRegistry:
161
177
  self.inner_used_api = inner_used_api
162
178
  self.supported_api_list_path = supported_api_list_path
163
179
  self.api_templates = api_templates
164
- self.backlist = backlist if backlist else []
180
+ self.blacklist = blacklist if blacklist else []
165
181
  self.all_api_registered = False
166
182
 
167
183
  @staticmethod
168
- def store_ori_attr(ori_api_group, api_list, api_ori_attr):
184
+ def store_ori_attr(ori_api_groups, api_list, api_ori_attr):
169
185
  for api in api_list:
170
- api_ori_attr[api] = _get_attr(ori_api_group, api)
186
+ ori_api = None
187
+ for ori_api_group in ori_api_groups:
188
+ ori_api = ori_api or _get_attr(ori_api_group, api)
189
+ api_ori_attr[api] = ori_api
171
190
 
172
191
  @staticmethod
173
192
  def set_api_attr(api_group, attr_dict):
@@ -217,7 +236,7 @@ class ApiRegistry:
217
236
  self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
218
237
 
219
238
  def initialize_hook(self, hook_build_func):
220
- api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
239
+ api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.blacklist)
221
240
  wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
222
241
 
223
242
  for framework, api_types in self.api_types.items():
@@ -23,6 +23,7 @@ from msprobe.core.data_dump.json_writer import DataWriter
23
23
  from msprobe.core.common.log import logger
24
24
  from msprobe.core.common.const import Const
25
25
  from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
26
+ from msprobe.core.common.megatron_utils import MegatronStepInfo, get_micro_step, is_megatron
26
27
 
27
28
 
28
29
  def build_data_collector(config):
@@ -41,6 +42,7 @@ class DataCollector:
41
42
  self.module_count = {}
42
43
  self.scope = ScopeFactory(self.config).build_scope()
43
44
  self.backward_module_names = {}
45
+ self.params_grad_record = {}
44
46
  self.optimizer_status = ""
45
47
  self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
46
48
  atexit.register(self.write_json_at_exit)
@@ -118,12 +120,16 @@ class DataCollector:
118
120
  self.set_is_recomputable(data_info, is_recompute)
119
121
  if self.config.level == Const.LEVEL_L2:
120
122
  return
123
+ self.call_stack_collect(name)
121
124
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
122
125
 
123
- except Exception:
126
+ except Exception as e:
127
+ # 取异常类名作为“类型”做去重
128
+ error_type = type(e).__name__
124
129
  tb = traceback.format_exc()
125
130
  self.data_writer.write_error_log(
126
- f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
131
+ f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}",
132
+ error_type=error_type
127
133
  )
128
134
 
129
135
  def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -139,13 +145,15 @@ class DataCollector:
139
145
  self.set_is_recomputable(data_info, is_recompute)
140
146
  if self.config.level == Const.LEVEL_L2:
141
147
  return
142
- self.call_stack_collect(name)
143
148
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
144
149
 
145
- except Exception:
150
+ except Exception as e:
151
+ # 取异常类名作为“类型”做去重
152
+ error_type = type(e).__name__
146
153
  tb = traceback.format_exc()
147
154
  self.data_writer.write_error_log(
148
- f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
155
+ f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}",
156
+ error_type=error_type
149
157
  )
150
158
 
151
159
  def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
@@ -154,10 +162,13 @@ class DataCollector:
154
162
  return
155
163
  self.data_processor.analyze_forward(name, module, module_input_output)
156
164
 
157
- except Exception:
165
+ except Exception as e:
166
+ # 取异常类名作为“类型”做去重
167
+ error_type = type(e).__name__
158
168
  tb = traceback.format_exc()
159
169
  self.data_writer.write_error_log(
160
- f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
170
+ f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}",
171
+ error_type=error_type
161
172
  )
162
173
 
163
174
  def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -173,10 +184,12 @@ class DataCollector:
173
184
  self.call_stack_collect(name)
174
185
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
175
186
 
176
- except Exception:
187
+ except Exception as e:
188
+ error_type = type(e).__name__
177
189
  tb = traceback.format_exc()
178
190
  self.data_writer.write_error_log(
179
- f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}"
191
+ f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}",
192
+ error_type=error_type
180
193
  )
181
194
 
182
195
  def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
@@ -185,10 +198,12 @@ class DataCollector:
185
198
  return
186
199
  self.data_processor.analyze_backward(name, module, module_input_output)
187
200
 
188
- except Exception:
201
+ except Exception as e:
202
+ error_type = type(e).__name__
189
203
  tb = traceback.format_exc()
190
204
  self.data_writer.write_error_log(
191
- f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
205
+ f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}",
206
+ error_type=error_type
192
207
  )
193
208
 
194
209
  def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -206,10 +221,12 @@ class DataCollector:
206
221
  self.backward_module_names[module_name] = True
207
222
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
208
223
 
209
- except Exception:
224
+ except Exception as e:
225
+ error_type = type(e).__name__
210
226
  tb = traceback.format_exc()
211
227
  self.data_writer.write_error_log(
212
- f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}"
228
+ f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}",
229
+ error_type=error_type
213
230
  )
214
231
 
215
232
  def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -223,10 +240,12 @@ class DataCollector:
223
240
  self.set_is_recomputable(data_info, is_recompute)
224
241
  self.handle_data(name, data_info)
225
242
 
226
- except Exception:
243
+ except Exception as e:
244
+ error_type = type(e).__name__
227
245
  tb = traceback.format_exc()
228
246
  self.data_writer.write_error_log(
229
- f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
247
+ f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}",
248
+ error_type=error_type
230
249
  )
231
250
 
232
251
  def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
@@ -240,25 +259,32 @@ class DataCollector:
240
259
  self.set_is_recomputable(data_info, is_recompute)
241
260
  self.handle_data(name, data_info)
242
261
 
243
- except Exception:
262
+ except Exception as e:
263
+ error_type = type(e).__name__
244
264
  tb = traceback.format_exc()
245
265
  self.data_writer.write_error_log(
246
- f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
266
+ f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}",
267
+ error_type=error_type
247
268
  )
248
269
 
249
270
  def update_construct(self, name):
250
271
  if self.config.level not in DataCollector.level_without_construct:
251
272
  if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
252
273
  if self.optimizer_status_first_start[self.optimizer_status]:
253
- self.data_writer.update_construct({self.optimizer_status: None})
274
+ self.data_writer.update_construct(
275
+ {self.optimizer_status: None if not is_megatron() else [None, get_micro_step()]})
254
276
  self.optimizer_status_first_start[self.optimizer_status] = False
255
- self.data_writer.update_construct({name: self.optimizer_status})
277
+ self.data_writer.update_construct(
278
+ {name: self.optimizer_status if not is_megatron() else [self.optimizer_status, get_micro_step()]})
256
279
  else:
257
280
  if self.config.level == Const.LEVEL_MIX and \
258
281
  not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
259
282
  self.data_writer.update_construct(
260
283
  {name: self.module_processor.api_parent_node.get(threading.get_ident())}
261
284
  )
285
+ if MegatronStepInfo.is_megatron:
286
+ micro_step_number = max(MegatronStepInfo.forward_micro_step, MegatronStepInfo.backward_micro_step)
287
+ self.data_writer.update_construct({Const.MEGATRON_MICRO_STEP_NUMBER: micro_step_number})
262
288
 
263
289
  self.data_writer.update_construct(self.module_processor.module_node)
264
290
 
@@ -282,20 +308,36 @@ class DataCollector:
282
308
  self.data_processor.update_iter(current_iter)
283
309
 
284
310
  def params_data_collect(self, name, param_name, pid, data):
311
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
312
+ self.update_api_or_module_name(grad_name)
313
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
314
+ if self.data_writer.cache_data.get("data"):
315
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
316
+ self.params_grad_record[grad_name] = False
317
+ return
318
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
319
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
320
+ self.params_grad_record[grad_name] = False
321
+
322
+ def params_data_collect_in_bw_hook(self, params_dict, name):
285
323
  try:
286
- grad_name = name + Const.SEP + Const.PARAMS_GRAD
287
- self.update_api_or_module_name(grad_name)
288
- if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
289
- if self.data_writer.cache_data.get("data"):
290
- self.data_writer.cache_data.get("data").pop(grad_name, None)
324
+ if not params_dict:
291
325
  return
292
- data_info = self.data_processor.analyze_params(grad_name, param_name, data)
293
- self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
294
- except Exception:
326
+ ori_name = name.rsplit(Const.SEP, 2)[0]
327
+ for param_name, param in params_dict.items():
328
+ grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
329
+ self.update_api_or_module_name(grad_name)
330
+ if self.params_grad_record.get(grad_name, False):
331
+ grad = param.grad if hasattr(param, "grad") else None
332
+ data_info = self.data_processor.analyze_params(grad_name, param_name, grad)
333
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
334
+ except Exception as e:
335
+ error_type = type(e).__name__
295
336
  tb = traceback.format_exc()
296
337
  self.data_writer.write_error_log(
297
- f"[ERROR] params_data_collect failed: "
298
- f"name={name}, param_name={param_name}, pid={pid}\n{tb}"
338
+ f"[ERROR] params_data_collect_in_bw_hook failed: "
339
+ f"name={name}",
340
+ error_type=error_type
299
341
  )
300
342
 
301
343
  def debug_data_collect_forward(self, variable, name_with_count):
@@ -94,6 +94,8 @@ class BaseDataProcessor:
94
94
  def __init__(self, config, data_writer):
95
95
  self.data_writer = data_writer
96
96
  self.config = config
97
+ if self.data_writer is not None:
98
+ self.data_writer.config = config
97
99
  self.api_info_struct = {}
98
100
  self.stack_info_struct = {}
99
101
  self.current_api_or_module_name = None