mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,237 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import itertools
17
+ import os
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
+
21
+ import pandas as pd
22
+ from mindspore import ops
23
+ from mindspore import Tensor
24
+ from mindspore import _no_grad
25
+
26
+ from msprobe.core.common.log import logger
27
+ from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
28
+ from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner
29
+ from msprobe.core.common.const import FileCheckConst, MonitorConst
30
+
31
+
32
class BCOLORS:
    """ANSI escape sequences for decorating terminal/log output.

    Prefix a message with one of the codes and append ``ENDC`` to restore the
    terminal's default rendering, e.g.
    ``f"{BCOLORS.WARNING}message{BCOLORS.ENDC}"``.
    """

    # Bright foreground colours.
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKCYAN = '\033[96m'     # bright cyan
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red

    # Attribute control.
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
42
+
43
+
44
@dataclass
class WriterInput:
    """Constructor arguments shared by the monitor writers in this module."""

    # Output directory; CSVWriterWithAD creates it and writes its CSV files there.
    path: str
    # Anomaly-detection rules handed to AnomalyScanner.scan; a falsy value
    # makes BaseWriterWithAD.add_scalar skip detection entirely.
    ad_rules: list
    # Stored on the writer as ``job_id``; not otherwise used in this module.
    job_id: str
    # Optional factory; when set, its ``create(tag, message, step)`` result is
    # appended to the writer's anomaly list for every reported anomaly.
    anomaly_factory: AnomalyDataFactory = None
    # Number of decimals the CSV writer rounds metric values to.
    ndigits: int = 6
    # Number of consecutive steps grouped into one CSV file.
    step_count_per_record: int = 1
52
+
53
+
54
class BaseWriterWithAD:
    """Metric writer base class that runs anomaly detection (AD) on each scalar.

    For every tag an exponential moving average of the absolute metric value is
    maintained. That history and the current value are passed to
    ``AnomalyScanner.scan``; reported anomalies are logged and, when an anomaly
    factory is configured, collected in ``self.anomalies``.
    """

    def __init__(self, writer_input: WriterInput):
        # tag -> {'avg': moving average of |value|, 'count': number of updates}
        self.tag2scalars = {}
        self.ad_rules = writer_input.ad_rules
        self.job_id = writer_input.job_id
        self.anomaly_factory = writer_input.anomaly_factory
        self.anomalies = []
        self.ndigits = writer_input.ndigits
        # Weight given to the previous average when the moving average is updated.
        self.beta = 0.99

    @staticmethod
    def stack_tensors(tensor_list):
        """
        Convert a mixed list of MindSpore tensors and plain values into plain values.

        All ``Tensor`` items are stacked and converted with one ``tolist()`` call
        instead of item by item; every other item is passed through unchanged.
        The order of the input is preserved.

        :param tensor_list: e.g. [Tensor(-1.6165), (2, 3), Tensor(-1.7777)]
        :return: list of the same length and order as ``tensor_list``, with each
            tensor replaced by its ``tolist()`` value
        """
        device_tensors = [item for item in tensor_list if isinstance(item, Tensor)]

        # Stack the device tensors first, then convert the whole batch at once.
        # With no tensors present nothing is ever read from the converted
        # values, so an empty list suffices and no framework call is needed.
        converted = iter(ops.stack(device_tensors).tolist() if device_tensors else [])

        # Restore the input order: tensors take the next converted value,
        # everything else is kept as is.
        return [next(converted) if isinstance(item, Tensor) else item for item in tensor_list]

    def get_anomalies(self):
        """Return the list of anomalies detected so far."""
        return self.anomalies

    def clear_anomalies(self):
        """Empty the list of detected anomalies in place."""
        self.anomalies.clear()

    def add_scalar(self, tag, scalar_value, global_step=None, need_explain=False):
        """If an anomaly is detected, the anomaly information is recorded and added to self.anomalies.

        Args:
            tag (tuple): tuple of tag_name and op like
                ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
            scalar_value (float | Tensor): current metric value; a Tensor is
                converted with ``item()``.
            global_step (int): step the value belongs to (used in the message).
            need_explain (bool): not used by this implementation; kept so that
                subclasses can share the signature.
        Returns:
            None
        """
        # Nothing to do without rules; "shape"/"dtype" entries are not scanned.
        if not self.ad_rules or tag[-1] in ["shape", "dtype"]:
            return
        if isinstance(scalar_value, Tensor):
            scalar_value = scalar_value.item()
        avg = self._update_tag2scalars(tag, scalar_value)
        detected, rule_name = self._ad(scalar_value, history=avg)
        if not detected:
            return
        # The turbulence rule is only reported for "norm" and "mean" metrics.
        if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]:
            return
        exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, "
                             f"current value {scalar_value}, history mean {avg}.")
        logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}")
        # append to self.anomalies for dump
        if self.anomaly_factory:
            self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))

    def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False):
        """Feed every metric in ``metric_value`` through ``add_scalar``.

        Args:
            op_list (list): metric/op names, e.g. ["min", "max"].
            metric_value (dict): tag_name -> {op: value}. Tags are built as the
                product of the tag names with ``op_list`` and paired positionally
                with the flattened values, so each inner dict is assumed to hold
                its values in ``op_list`` order -- TODO confirm against callers.
            step (int): step passed on to ``add_scalar``.
            prefix (str): not used by this implementation.
            need_explain (bool): forwarded to ``add_scalar``.
        """
        if not metric_value:
            return
        tensors = []
        tags = list(itertools.product(metric_value.keys(), op_list))
        for op2tensor in metric_value.values():
            tensors.extend(op2tensor.values())

        if not tensors:
            return

        with _no_grad():
            metric_list = self.stack_tensors(tensors)
        for tag, metric in zip(tags, metric_list):
            self.add_scalar(tag, metric, step, need_explain)

    def _ad(self, scalar_value, history):
        """Scan the current value against its history with the configured rules."""
        return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)

    def _update_tag2scalars(self, tag, scalar_value):
        """Update the moving average and count of a scalar value associated with a tag.

        The average is an exponential moving average of the absolute value,
        weighted by ``self.beta``. On the first value seen for a tag the average
        is initialised to that absolute value.

        Args:
            tag (tuple): The tag identifier, as passed to ``add_scalar``.
            scalar_value (float): The scalar value to be added.

        Returns:
            float: The average value before update.
        """
        abs_scalar_value = abs(scalar_value)
        if tag not in self.tag2scalars:
            self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0}
        record = self.tag2scalars[tag]
        avg = record['avg']
        record['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value
        record['count'] += 1
        return avg
170
+
171
+
172
class CSVWriterWithAD(BaseWriterWithAD):
    """Anomaly-detecting writer that buffers metric values and appends them to CSV files."""

    def __init__(self, writer_input: WriterInput):
        super().__init__(writer_input)

        self.log_dir = writer_input.path
        create_directory(self.log_dir)
        change_mode(self.log_dir, FileCheckConst.DATA_DIR_AUTHORITY)
        # row name -> values collected since the last write_csv call
        self.context_dict = defaultdict(list)
        self.header = []
        self.step_count_per_record = writer_input.step_count_per_record

    def get_step_interval(self, step):
        """Return the (first, last) step of the record window that contains ``step``."""
        window = self.step_count_per_record
        first_step = step // window * window
        return first_step, first_step + window - 1

    def write_csv(self, prefix, step):
        """
        Append the buffered rows to the CSV file of the current step window and clear the buffer.

        Args:
            prefix[str]: prefix of output csv file e.g. grad_unreduced
            step[int]
        """
        if not self.context_dict:
            return

        step_start, step_end = self.get_step_interval(step)
        filepath = os.path.join(self.log_dir, f'{prefix}_{step_start}-{step_end}.csv')
        # A new file first receives a header-only frame.
        if not os.path.exists(filepath):
            write_df_to_csv(pd.DataFrame(columns=self.header), filepath)

        rows = []
        for name, values in self.context_dict.items():
            row = name.split(MonitorConst.NAME_SEP) + values
            # The step becomes the third column of every row.
            row.insert(2, step)
            rows.append(row)
        frame = pd.DataFrame(rows).round(self.ndigits).fillna("nan")
        write_df_to_csv(frame, filepath, mode='a+', header=False)
        self.context_dict = defaultdict(list)

    def add_scalar(self, tag, scalar_value, global_step, need_explain=False):
        """
        Run anomaly detection on the value and buffer it for the next CSV write.

        ``tag`` looks like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
        The row name is the part of ``tag[0]`` before the first '/'. With
        ``need_explain`` set, '.input' / '.output' is appended when the last
        '/'-separated part contains 'pre' / 'post'.
        """
        super().add_scalar(tag, scalar_value, global_step, need_explain=False)
        path_parts = tag[0].split('/')
        name = path_parts[0]
        if need_explain:
            stage = path_parts[-1]
            if 'pre' in stage:
                name += '.input'
            if 'post' in stage:
                name += '.output'
        self.context_dict[name].append(scalar_value)

    def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False, **kwargs):
        """
        Scan and buffer all metrics, pick the CSV header and flush the buffer to disk.

        The incoming ``need_explain`` argument is overridden: it is True exactly
        when ``prefix`` equals 'other'. The micro-step header is used for the
        activation prefixes or when ``use_micro_step`` is passed as a truthy
        keyword argument.
        """
        need_explain = prefix == 'other'
        super().write_metrics(op_list, metric_value, step, prefix='', need_explain=need_explain)

        wants_micro_step = prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step")
        if wants_micro_step:
            self.header = MonitorConst.CSV_HEADER_MICRO_STEP + op_list
        else:
            self.header = MonitorConst.CSV_HEADER + op_list
        self.write_csv(prefix, step)

    def close(self):
        """No-op; this writer holds no open resources between calls."""
        pass
@@ -46,6 +46,8 @@ def get_max(x: Tensor):
46
46
 
47
47
@_no_grad()
def get_zeros(x: Tensor, eps: float):
    """Return the fraction of elements of ``x`` whose absolute value is below ``eps``.

    For an empty tensor a NaN tensor is returned instead of dividing by zero.
    """
    total = x.numel()
    if total == 0:
        return Tensor(float('nan'))
    near_zero_count = mint.sum(mint.abs(x) < eps)
    return near_zero_count / total
50
52
 
51
53
 
@@ -54,10 +56,21 @@ def get_nans(t):
54
56
  return ops.isnan(t.astype(mstype.float32)).sum()
55
57
 
56
58
 
57
- FUNC_MAP = {"min" : get_min,
58
- "max" : get_max,
59
- "mean" : get_mean,
60
- "norm" : get_norm,
61
- "nans" : get_nans,
62
- "zeros": get_zeros
63
- }
59
def get_shape(t):
    """Return the ``shape`` attribute of ``t``."""
    shape = t.shape
    return shape
61
+
62
+
63
def get_dtype(t):
    """Return the ``dtype`` attribute of ``t``."""
    dtype = t.dtype
    return dtype
65
+
66
+
67
# Metric name -> function that computes the metric from a tensor.
FUNC_MAP = {
    "min": get_min,
    "max": get_max,
    "mean": get_mean,
    "norm": get_norm,
    "nans": get_nans,
    "zeros": get_zeros,
    # These two return the tensor's attributes rather than a computed value.
    "shape": get_shape,
    "dtype": get_dtype,
}