mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,201 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import sys
18
- import argparse
19
- import ast
20
- import heapq
21
-
22
- from msprobe.pytorch.common.log import logger
23
- from msprobe.core.common.const import MonitorConst
24
- from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \
25
- check_file_or_directory_path, load_json
26
- from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData
27
-
28
-
29
class AnomalyDataWriter:
    """Write detected anomaly data of one rank into a JSON file.

    The target file is ``<dump_path>/rank<rank>/<MonitorConst.ANOMALY_JSON>``.
    """

    def __init__(self, dump_path, rank) -> None:
        rank_dir = os.path.join(dump_path, f"rank{rank}")
        self.dump_path = dump_path
        self.dump_rank_dir = rank_dir
        self.json_path = os.path.join(rank_dir, MonitorConst.ANOMALY_JSON)

    @staticmethod
    def get_anomaly_dict(anomalies):
        """Convert an iterable of anomalies into a JSON-ready dict.

        Each item contributes ``item.get_key() -> item.to_dict()``; a later
        item overwrites an earlier one that produced the same key.
        """
        return {item.get_key(): item.to_dict() for item in anomalies}

    def init_detected_json(self):
        """Prepare the output directories and reset the JSON file to ``{}``."""
        check_path_before_create(self.dump_path)
        for directory in (self.dump_path, self.dump_rank_dir):
            if not os.path.exists(directory):
                create_directory(directory)

        json_file = self.json_path
        if os.path.exists(json_file):
            # A file left over from a previous run is validated, then removed.
            check_file_or_directory_path(json_file, isdir=False)
            logger.warning(f"The existing file will be deleted: {json_file}.")
            remove_path(json_file)
        save_json(json_file, {}, indent=1)

    def write_detected_json(self, anomalies):
        """Merge the given anomalies into the JSON file on disk.

        Args:
            anomalies: list of GradAnomalyData objects.
        """
        new_entries = self.get_anomaly_dict(anomalies)
        logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.")

        if os.path.exists(self.json_path):
            merged = load_json(self.json_path)
        else:
            merged = {}
        merged.update(new_entries)
        save_json(self.json_path, merged, indent=1)
74
-
75
-
76
class AnomalyDataLoader:
    """Load anomaly records from the per-rank JSON files under ``data_path``."""

    def __init__(self, data_path) -> None:
        self.data_path = data_path

    @staticmethod
    def create_instances_from_dict(anomalies_dict: dict):
        """Build one GradAnomalyData per value of ``anomalies_dict``.

        A value that cannot be turned into an instance is logged as a warning
        and skipped; the remaining values are still processed.
        """
        loaded = []
        for record in anomalies_dict.values():
            try:
                instance = GradAnomalyData(**record)
            except KeyError as e:
                logger.warning(f"Missing key in anomaly data: {e}.")
            except Exception as e:
                logger.warning(f"Value error when creating a GradAnomalyData instance: {e}.")
            else:
                loaded.append(instance)
        return loaded

    def get_anomalies_from_jsons(self):
        """Walk ``data_path`` and read anomalies from every ``rankK`` JSON file.

        Entries that are not directories, or directories without the anomaly
        JSON file, are ignored.

        Returns:
            list of GradAnomalyData objects gathered from all sub-directories.
        """
        collected = []
        check_file_or_directory_path(self.data_path, isdir=True)
        for entry in os.listdir(self.data_path):
            rank_path = os.path.join(self.data_path, entry)
            json_path = os.path.join(rank_path, MonitorConst.ANOMALY_JSON)
            if os.path.isdir(rank_path) and os.path.exists(json_path):
                collected.extend(self.create_instances_from_dict(load_json(json_path)))
        return collected
109
-
110
-
111
class AnomalyAnalyse:
    """Pick the highest-priority anomalies and write them back to disk."""

    def __init__(self) -> None:
        self.sorted_anomalies = []

    def get_range_top_k(self, topk, step_list, anomalies):
        """Return the first ``topk`` anomalies, optionally limited to ``step_list``.

        An empty ``step_list`` disables the step filter. The result is in
        ascending order of the anomalies' own ordering and is also stored in
        ``self.sorted_anomalies``.
        """
        if step_list:
            candidates = [item for item in anomalies if item.step in step_list]
        else:
            candidates = anomalies

        if len(candidates) <= topk:
            # Everything fits: a plain sort is enough.
            ranked = sorted(candidates)
        else:
            ranked = list(heapq.nsmallest(topk, candidates))
        self.sorted_anomalies = ranked
        return self.sorted_anomalies

    def rewrite_sorted_anomalies(self, output_path):
        """Write the most recently selected anomalies to ``output_path``.

        Any existing analyse JSON file in that directory is removed first.
        """
        check_file_or_directory_path(output_path, isdir=True)

        payload = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies)
        logger.info(f"{MonitorConst.ANALYSE_JSON} is at {output_path}.")
        target = os.path.join(output_path, MonitorConst.ANALYSE_JSON)
        if os.path.exists(target):
            logger.warning(f"The existing file will be deleted: {target}.")
            remove_path(target)
        save_json(target, payload, indent=1)
146
-
147
-
148
- def _get_parse_args():
149
- parser = argparse.ArgumentParser()
150
- parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str,
151
- help="<Required> The anomaly detect result dictionary: generate from monitor tool.",
152
- required=True,
153
- )
154
- parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
155
- help="<optional> The analyse task result out path.",
156
- required=False,
157
- )
158
- parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int,
159
- help="<optional> Top K number of earliest anomalies.",
160
- required=False,
161
- )
162
- parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str,
163
- help="<optional> Analyse which steps.",
164
- required=False,
165
- )
166
- return parser.parse_args(sys.argv[1:])
167
-
168
-
169
- def _get_step_and_stop(args):
170
- try:
171
- step_list = ast.literal_eval(args.step_list)
172
- if not isinstance(step_list, list):
173
- raise ValueError(f"{args.step_list} is not a list.")
174
- except (ValueError, SyntaxError, RecursionError) as e:
175
- raise Exception(f"The step list must be a resolvable list type.") from e
176
- if args.top_k_number <= 0:
177
- raise Exception("The top k number must be greater than 0.")
178
- return step_list, args.top_k_number
179
-
180
-
181
def _anomaly_analyse():
    """Run the command-line workflow: load, rank, persist and log anomalies."""
    cli_args = _get_parse_args()
    step_list, top_k_number = _get_step_and_stop(cli_args)

    anomalies = AnomalyDataLoader(cli_args.data_path_dir).get_anomalies_from_jsons()
    analyser = AnomalyAnalyse()
    top_anomalies = analyser.get_range_top_k(top_k_number, step_list, anomalies)

    # Without an explicit output path the result goes next to the input data.
    output_dir = cli_args.out_path or cli_args.data_path_dir
    analyser.rewrite_sorted_anomalies(output_dir)

    logger.info(f"Top {top_k_number} anomalies are listed as follows:")
    for index, anomaly in enumerate(top_anomalies):
        logger.info(f"{index}: {anomaly.message}")


if __name__ == "__main__":
    _anomaly_analyse()
    logger.info("Analyse task completed.")
@@ -1,410 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import itertools
16
- import os
17
- import statistics as st
18
- import sys
19
- from abc import ABC
20
- from collections import defaultdict
21
- from dataclasses import dataclass, field
22
- from typing import List
23
-
24
- import pandas as pd
25
- import torch
26
- from torch.utils.tensorboard import SummaryWriter
27
-
28
- from msprobe.core.common.const import FileCheckConst, MonitorConst
29
- from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
30
- from msprobe.pytorch.common.log import logger
31
-
32
-
33
class ScanRule(ABC):
    """Base class for anomaly scan rules; subclasses override ``apply``."""

    name = "ScanRule"

    def apply(self, history, cur):
        """Decide whether ``cur`` is anomalous given ``history``.

        The base implementation always raises NotImplementedError.
        """
        message = "abstract method apply is not implemented"
        raise NotImplementedError(message)
38
-
39
-
40
class AnomalyTurbulence(ScanRule):
    """Rule that flags a value lying beyond ``baseline * (1 + threshold)``."""

    name = "AnomalyTurbulence"

    def __init__(self, threshold) -> None:
        self.threshold = threshold

    def apply(self, history, cur):
        """Return whether ``cur`` exceeds the bound derived from ``history``.

        ``history`` is averaged when it is a list and used as-is otherwise.
        For a positive baseline the check is ``cur > bound``; for a zero or
        negative baseline it is ``cur < bound``.
        """
        if isinstance(history, list):
            baseline = st.mean(history)
        else:
            baseline = history

        bound = baseline + baseline * self.threshold
        return cur > bound if baseline > 0 else cur < bound
54
-
55
-
56
class AnomalyScanner:
    """Helpers to build scan rules from configuration and to apply them."""

    @staticmethod
    def load_rules(specs: List[dict]):
        """Instantiate the rules described by ``specs``.

        Args:
            specs: e.g. ``[{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}]``,
                or None.

        Returns:
            list of rule instances. Specs that are malformed, name something
            that is not a concrete ScanRule subclass of this module, or fail
            to instantiate are logged and skipped.
        """
        if specs is None:
            return []
        cur_module = sys.modules.get(__name__)
        alert_rules = []
        for spec in specs:
            if not isinstance(spec, dict):
                logger.warning(f"Spec is missing required keys: {spec}")
                continue
            # get() yields None when a key is absent.
            rule_cls_name = spec.get("rule_name")
            rule_args = spec.get("args")

            # Both keys are required.
            if rule_cls_name is None or rule_args is None:
                logger.warning(f"Spec is missing required keys: {spec}")
                continue

            try:
                rule_cls = getattr(cur_module, rule_cls_name)
            except (AttributeError, TypeError):
                # TypeError: the configured name is not a string.
                logger.error(f"Rule class '{rule_cls_name}' not found in the current module.")
                continue

            # The name comes from configuration and getattr resolves ANY
            # module-level object (imported helpers included), so only real
            # rule classes may be called with the configured arguments.
            is_rule_cls = (
                isinstance(rule_cls, type)
                and issubclass(rule_cls, ScanRule)
                and rule_cls is not ScanRule
            )
            if not is_rule_cls:
                logger.error(f"Rule class '{rule_cls_name}' not found in the current module.")
                continue

            try:
                rule_instance = rule_cls(**rule_args)
                alert_rules.append(rule_instance)
            except Exception as e:
                logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}")
                continue

        return alert_rules

    @staticmethod
    def scan(scan_rules: List[ScanRule], history, cur):
        """Apply the rules in order and stop at the first one that fires.

        Returns:
            ``(result, rule.name)`` for the first rule whose ``apply`` result
            is truthy; otherwise ``(last_result, None)``, where
            ``last_result`` is False when ``scan_rules`` is empty.
        """
        anomaly = False
        for rule in scan_rules:
            anomaly = rule.apply(history, cur)
            if anomaly:
                return anomaly, rule.name
        return anomaly, None
100
-
101
-
102
class BCOLORS:
    """ANSI escape sequences used to decorate log output."""

    # Colours
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Reset all attributes
    ENDC = '\033[0m'

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
112
-
113
-
114
class AnomalyDataFactory(ABC):
    """Create GradAnomalyData records carrying this rank's fixed context."""

    def __init__(self, rank, pp_stage, group_mates):
        super().__init__()
        self.rank = rank
        self.pp_stage = pp_stage
        self.group_mates = group_mates
        self.micro_step = 0
        self.name2callid = {}

    def set_call_id(self, name2callid):
        """Replace the tag-name -> call-id mapping consulted by ``create``."""
        self.name2callid = name2callid

    def create(self, tag, message, step):
        """Build a GradAnomalyData instance for a detected anomaly.

        Args:
            tag (tuple): metric tag, e.g.
                ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
            message (str): anomaly detect message.
            step (int): training step.

        Raises:
            ValueError: if ``tag`` is not a tuple of length 2.
        """
        if not (isinstance(tag, tuple) and len(tag) == 2):
            raise ValueError("tag must be a tuple with length 2")

        tag_name = tag[0]
        param_name = tag_name.partition('/')[0]
        # Tags without a registered call id fall back to -1.
        call_id = self.name2callid.get(tag_name, -1)
        # A leading "<n><NAME_SEP>" on the parameter name carries the vpp stage.
        has_vpp_prefix = MonitorConst.NAME_SEP in param_name
        vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) if has_vpp_prefix else 0

        return GradAnomalyData(
            rank=self.rank,
            step=step,
            micro_step=self.micro_step,
            pp_stage=self.pp_stage,
            vpp_stage=vpp_stage,
            call_id=call_id,
            tag_name=tag_name,
            message=message,
            group_mates=self.group_mates,
        )
155
-
156
-
157
class TrainStage:
    """Integer codes for training stages, in the order they are compared."""

    DEFAULT_STAGE = -1   # used when a tag maps to no known stage
    FORWARD_STAGE = 0
    BACKWARD_STAGE = 1
    OPTIMIZER_STAGE = 2
162
-
163
-
164
# Metric-tag suffixes grouped by the training stage they belong to.
FORWARD_KEY = [MonitorConst.ACTV]
BACKWARD_KEY = [
    MonitorConst.ACTVGRAD,
    MonitorConst.PRE_GRAD,
    MonitorConst.POST_GRAD,
    MonitorConst.ACC_GRAD,
]
OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ]

# Lookup table: tag suffix -> TrainStage code.
TRAIN_STAGE = {
    key_: stage
    for stage, keys in (
        (TrainStage.FORWARD_STAGE, FORWARD_KEY),
        (TrainStage.BACKWARD_STAGE, BACKWARD_KEY),
        (TrainStage.OPTIMIZER_STAGE, OPTIMIZER_KEY),
    )
    for key_ in keys
}
173
-
174
-
175
@dataclass(eq=True)
class GradAnomalyData:
    """One detected anomaly together with where and when it happened.

    Equality compares rank, step, micro_step, pp_stage, vpp_stage and
    call_id; tag_name, message and group_mates are excluded from it.
    """

    rank: int = 0
    step: int = 0
    micro_step: int = 0
    pp_stage: int = 0
    vpp_stage: int = 0
    call_id: int = 0
    tag_name: str = field(default=None, compare=False)
    message: str = field(default="", compare=False)
    group_mates: list = field(default=None, compare=False)

    def __lt__(self, other):
        """Order anomalies so that the highest-priority one sorts first.

        Rules:
            smaller step and micro_step have higher priority;
            for vpp and pp, smaller values win in the forward stage and
            larger values win in every other stage;
            smaller call_id has higher priority.
        """
        if not isinstance(other, GradAnomalyData):
            return NotImplemented

        self_train_stage = self.get_train_stage(self.tag_name)
        other_train_stage = self.get_train_stage(other.tag_name)

        def vpp_pp_comparator(anomaly):
            """Return the (vpp, pp) sort key: ascending in forward, descending otherwise.

            Using self's stage for both operands is safe: the train stage is
            compared before vpp/pp, so these values only decide the result
            when both stages are equal.
            """
            if self_train_stage == TrainStage.FORWARD_STAGE:
                return anomaly.vpp_stage, anomaly.pp_stage
            else:
                return -anomaly.vpp_stage, -anomaly.pp_stage

        self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id]
        other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id]
        return self_cmp < other_cmp

    def __le__(self, other):
        if not isinstance(other, GradAnomalyData):
            return NotImplemented
        return self == other or self < other

    @staticmethod
    def get_train_stage(tag_name):
        """Map a tag name to its training-stage code.

        :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq"
        :return: int, if forward return 0; if backward return 1; if optimizer return 2;
                 TrainStage.DEFAULT_STAGE for an unknown suffix or a non-string tag
        """
        # tag_name defaults to None on the dataclass; treat that (and any other
        # non-string) as "unknown stage" instead of failing inside a sort.
        if not isinstance(tag_name, str):
            return TrainStage.DEFAULT_STAGE
        key_ = tag_name.split("/")[-1]
        return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE)

    def to_dict(self):
        """Return the fields as a new dict (a shallow copy of the instance state)."""
        # A copy keeps callers from mutating the instance through the result.
        return dict(self.__dict__)

    def get_key(self):
        """Return the JSON key: ``<tag_name>_step_<step>_call_<call_id>``."""
        # e.g. tag_name 0:1.self_attention.core_attention_flash_0/rank0/input_grad
        return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)])
236
-
237
-
238
@dataclass
class WriterInput:
    """Constructor arguments shared by the anomaly-detecting writers."""

    # Output location handed to the concrete writer.
    path: str
    # Scan rules applied to every recorded scalar.
    ad_rules: list
    job_id: str
    # Optional factory used to record detected anomalies.
    anomaly_factory: AnomalyDataFactory = None
    # Number of decimal digits kept when values are rounded for output.
    ndigits: int = 6
    # Number of steps grouped into one CSV file.
    step_count_per_record: int = 1
246
-
247
-
248
class BaseWriterWithAD:
    """Writer base class that runs anomaly detection on every scalar."""

    def __init__(self, writer_input: WriterInput):
        self.tag2scalars = {}
        self.ad_rules = writer_input.ad_rules
        self.job_id = writer_input.job_id
        self.anomaly_factory = writer_input.anomaly_factory
        self.anomalies = []
        self.ndigits = writer_input.ndigits

    def get_anomalies(self):
        """Return the list of anomalies detected so far."""
        return self.anomalies

    def clear_anomalies(self):
        """Empty the list of detected anomalies in place."""
        self.anomalies.clear()

    def add_scalar(self, tag, scalar_value, global_step=None):
        """Check one scalar against the rules and record an anomaly if found.

        Args:
            tag (tuple): tuple of tag_name and tag like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
            scalar_value (float): scalar_value.
            global_step (int): global_step.
        Returns:
            None
        """
        if not self.ad_rules:
            return
        # The value is compared with the running average from BEFORE this update.
        previous_avg = self._update_tag2scalars(tag, scalar_value)
        detected, rule_name = self._ad(scalar_value, history=previous_avg)
        if not detected:
            return

        exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}."
        logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}")
        # Keep the anomaly for a later dump when a factory is configured.
        if self.anomaly_factory:
            self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))

    def write_metrics(self, ops, metric_value, step, prefix=''):
        """Feed every (name, op) metric of ``metric_value`` to ``add_scalar``.

        ``metric_value`` maps a name to a mapping whose values are tensors;
        the tensors are stacked and moved to CPU in slices of
        ``MonitorConst.SLICE_SIZE``. ``prefix`` is not used here.
        """
        if not metric_value:
            return
        tags = list(itertools.product(metric_value.keys(), ops))
        tensors = [tensor for op2tensor in metric_value.values() for tensor in op2tensor.values()]
        if not tensors:
            return

        slice_size = MonitorConst.SLICE_SIZE
        with torch.no_grad():
            for begin in range(0, len(tensors), slice_size):
                end = begin + slice_size
                metric_list = torch.stack(tensors[begin:end]).cpu()
                for tag, metric in zip(tags[begin:end], metric_list):
                    self.add_scalar(tag, metric, step)

    def _ad(self, scalar_value, history):
        """Run the configured rules; returns what ``AnomalyScanner.scan`` returns."""
        return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)

    def _update_tag2scalars(self, tag, scalar_value):
        """Fold ``scalar_value`` into the running average kept for ``tag``.

        Args:
            tag: The tag identifier (used as a dict key).
            scalar_value (float): The scalar value to be added.

        Returns:
            The average value before the update; for a tag seen for the
            first time this is ``scalar_value`` itself.
        """
        record = self.tag2scalars.setdefault(tag, {'avg': scalar_value, 'count': 0})
        previous_avg = record['avg']
        count = record['count']
        record['avg'] = (previous_avg * count + scalar_value) / (count + 1)
        record['count'] = count + 1
        return previous_avg
329
-
330
-
331
class CSVWriterWithAD(BaseWriterWithAD):
    """Anomaly-detecting writer that appends the collected metrics to CSV files."""

    def __init__(self, writer_input: WriterInput):
        super().__init__(writer_input)

        path = writer_input.path
        self.log_dir = path
        create_directory(path)
        change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY)
        # name -> metric values collected since the last write_csv call
        self.context_dict = defaultdict(list)
        self.header = []
        self.step_count_per_record = writer_input.step_count_per_record

    def get_step_interval(self, step):
        """Return the inclusive (first, last) step of the record containing ``step``."""
        count = step // self.step_count_per_record
        return count * self.step_count_per_record, (count + 1) * self.step_count_per_record - 1

    def write_csv(self, prefix, step):
        """Append the buffered metrics to the CSV file of the current step interval.

        Args:
            prefix[str]: prefix of output csv file e.g. grad_unreduced
            step[int]
        """
        if len(self.context_dict) == 0:
            return

        step_start, step_end = self.get_step_interval(step)
        filepath = os.path.join(self.log_dir, f'{prefix}_{step_start}-{step_end}.csv')
        if not os.path.exists(filepath):
            # First write for this interval: create the file with the header only.
            data_frame = pd.DataFrame(columns=self.header)
            write_df_to_csv(data_frame, filepath)

        new_data = []
        for name, metric_value in self.context_dict.items():
            new_line = name.split(MonitorConst.NAME_SEP) + metric_value
            # The step goes into the third column of every row.
            new_line.insert(2, step)
            new_data.append(new_line)

        new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
        write_df_to_csv(new_data, filepath, mode='a+', header=False)
        # Start a fresh buffer for the next write.
        self.context_dict = defaultdict(list)

    def add_scalar(self, tag, scalar_value, global_step):
        """Run anomaly detection, then buffer the value under the tag's name part.

        ``tag`` looks like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
        """
        super().add_scalar(tag, scalar_value, global_step)

        name = tag[0].split('/')[0]
        self.context_dict[name].append(scalar_value.item())

    def write_metrics(self, ops, metric_value, step, prefix=''):
        """Collect the metrics through the base class, then flush them to CSV."""
        # Forward the caller's prefix instead of a hard-coded empty string.
        super().write_metrics(ops, metric_value, step, prefix=prefix)

        if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]:
            self.header = MonitorConst.CSV_HEADER_XY + ops
        else:
            self.header = MonitorConst.CSV_HEADER + ops
        self.write_csv(prefix, step)

    def close(self):
        """Nothing to release."""
        pass
392
-
393
-
394
class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD):
    """TensorBoard SummaryWriter that also runs anomaly detection on each scalar."""

    def __init__(self, writer_input: WriterInput):
        path = writer_input.path
        if not os.path.exists(path):
            create_directory(path)
        try:
            # super(SummaryWriter, self) skips SummaryWriter in the MRO, so the
            # first call initialises the anomaly-detection base with the writer
            # input; the second call initialises SummaryWriter with the path.
            super(SummaryWriter, self).__init__(writer_input)
            super().__init__(path)
        except Exception as e:
            logger.error(f'error when init summary writer at {path}: {e}')
            raise ValueError("Init summary writer error.") from e

    def add_scalar(self, tag, scalar_value, global_step):
        """Run anomaly detection on the tuple tag, then log under ``<tag[0]>_<tag[1]>``."""
        super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step)
        flat_tag = f'{tag[0]}_{tag[1]}'
        super().add_scalar(flat_tag, scalar_value, global_step)