mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -676,6 +676,7 @@ torch:
676
676
  - _batch_norm_impl_index
677
677
  - _convolution
678
678
  - _foreach_norm
679
+ - _fused_adamw_
679
680
  - _softmax_backward_data
680
681
  - abs
681
682
  - abs_
@@ -1258,6 +1259,7 @@ torch_npu:
1258
1259
  - npu_scatter_nd_update_
1259
1260
  - npu_scatter_nd_update
1260
1261
  - npu_prefetch
1262
+ - npu_dynamic_block_quant
1261
1263
 
1262
1264
  aten:
1263
1265
  - signbit
@@ -2009,4 +2011,23 @@ distributed:
2009
2011
 
2010
2012
  npu_distributed:
2011
2013
  - isend
2012
- - irecv
2014
+ - irecv
2015
+
2016
+ mindspeed:
2017
+ - dropout_add_layer_norm.npu_dropout_add_layer_norm
2018
+ - npu_rotary_position_embedding.npu_rotary_position_embedding
2019
+ - fusion_attention_v2.npu_fusion_attention
2020
+ - npu_mm_all_reduce_add_rms_norm.npu_mm_all_reduce_add_rms_norm
2021
+ - npu_mm_all_reduce_add_rms_norm_.npu_mm_all_reduce_add_rms_norm_
2022
+ - gmm.npu_gmm
2023
+ - gmm.npu_gmm_v2
2024
+ - npu_grouped_mat_mul_all_reduce.npu_grouped_mat_mul_all_reduce
2025
+ - ffn.npu_ffn
2026
+ - npu_moe_token_permute.npu_moe_token_permute
2027
+ - npu_moe_token_unpermute.npu_moe_token_unpermute
2028
+ - npu_ring_attention_update.npu_ring_attention_update
2029
+ - npu_matmul_add.npu_matmul_add_fp32
2030
+ - npu_groupmatmul_add.npu_groupmatmul_add_fp32
2031
+ - quant_gmm.npu_quant_gmm
2032
+ - quant_gmm.npu_quant_gmm_v2
2033
+ - npu_apply_fused_ema_adamw.npu_apply_fused_ema_adamw
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,11 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
- from msprobe.core.common.file_utils import load_yaml
17
+ import importlib
18
+ import inspect
19
+
20
+ from msprobe.core.common.file_utils import load_yaml, check_link
21
+ from msprobe.core.common.log import logger
18
22
 
19
23
 
20
24
  def get_ops():
@@ -26,3 +30,25 @@ def get_ops():
26
30
  wrap_torch = ops.get('torch')
27
31
  wrap_npu_ops = ops.get('torch_npu')
28
32
  return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops)
33
+
34
+
35
def dynamic_import_op(package, white_list):
    """Import whitelisted sub-modules of ``package`` and collect their functions.

    The directory that holds ``package.__file__`` is listed, and every entry
    whose file name appears in ``white_list`` is imported as
    ``<package name>.<file stem>``. A module that fails to import is skipped
    with a warning, so one broken module does not prevent the others from
    being collected.

    Args:
        package: An already imported package object; its ``__name__`` and
            ``__file__`` attributes are read.
        white_list: Container of file names (for example ``"gmm.py"``) that
            are allowed to be imported.

    Returns:
        dict: Maps ``"<file stem>.<function name>"`` to the function object for
        every member of each imported module accepted by ``inspect.isfunction``
        (this includes functions the module itself imported from elsewhere).
    """
    package_name = package.__name__
    ops = {}
    ops_dir, _ = os.path.split(package.__file__)
    # Project helper invoked on the directory before it is listed
    # (presumably a symlink safety check -- confirm against file_utils).
    check_link(ops_dir)
    for file_name in os.listdir(ops_dir):
        if file_name not in white_list:
            continue
        # Strip the actual extension instead of a fixed three characters, so a
        # whitelisted name that does not end in ".py" is not silently mangled.
        sub_module_name = os.path.splitext(file_name)[0]
        module_name = f"{package_name}.{sub_module_name}"
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            # Best effort: keep going, but report why the import failed so the
            # missing ops can be diagnosed from the log.
            logger.warning(f"import {module_name} failed! {type(e).__name__}: {e}")
            continue

        for func_name, func in inspect.getmembers(module, inspect.isfunction):
            ops[f"{sub_module_name}.{func_name}"] = func
    return ops
@@ -28,7 +28,10 @@ from msprobe.core.common.decorator import recursion_depth_decorator
28
28
  from msprobe.pytorch.common.log import logger
29
29
  from msprobe.pytorch.monitor.utils import get_target_output_dir
30
30
 
31
- all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
31
+ all_data_type_list = [
32
+ "actv", "actv_grad", "exp_avg", "exp_avg_sq",
33
+ "grad_unreduced", "grad_reduced", "param_origin", "param_updated"
34
+ ]
32
35
  CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
33
36
  MAX_PROCESS_NUM = 128
34
37
 
@@ -48,7 +51,7 @@ def parse_step_line(line, ops):
48
51
 
49
52
  def parse_step_fn(filepath):
50
53
  data = read_csv(filepath)
51
- ops = [k for k in data.keys() if k in MonitorConst.OP_LIST]
54
+ ops = [k for k in data.keys() if k in MonitorConst.OP_LIST[:-2]]
52
55
  parse_step_result = {}
53
56
 
54
57
  for _, line in data.iterrows():
@@ -76,6 +79,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
76
79
  for op, value in ops.items():
77
80
  tag = f"{vpp_name}/{op}"
78
81
  writer.add_scalar(tag, value, step)
82
+ writer.flush()
79
83
 
80
84
 
81
85
  @recursion_depth_decorator("update_dict", max_depth=50)
@@ -0,0 +1,259 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import itertools
16
+ import os
17
+ from collections import defaultdict
18
+ from dataclasses import dataclass
19
+
20
+ import pandas as pd
21
+ import torch
22
+ from torch.utils.tensorboard import SummaryWriter
23
+
24
+ from msprobe.core.common.const import FileCheckConst, MonitorConst
25
+ from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
26
+ from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner
27
+ from msprobe.pytorch.common.log import logger
28
+
29
+
30
class BCOLORS:
    """ANSI SGR escape sequences for decorating terminal/log output.

    Wrap a message as ``f"{BCOLORS.WARNING}{text}{BCOLORS.ENDC}"``; ``ENDC``
    resets the terminal attributes again.
    """

    # Bright foreground colours (SGR 9x codes).
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset all attributes.
    ENDC = "\033[0m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
40
+
41
+
42
@dataclass
class WriterInput:
    """Bundle of constructor arguments consumed by the writer classes below."""

    # Output directory (CSV log directory / SummaryWriter log directory).
    path: str
    # Anomaly-detection rules handed to AnomalyScanner.scan; an empty value
    # disables detection in BaseWriterWithAD.add_scalar.
    ad_rules: list
    # Job identifier; stored on the writer as-is.
    job_id: str
    # Optional factory whose ``create(tag, message, step)`` builds the anomaly
    # records kept by the writer; no records are kept when it is falsy.
    anomaly_factory: AnomalyDataFactory = None
    # Number of decimals the CSV writer rounds values to.
    ndigits: int = 6
    # Number of consecutive steps grouped into one CSV file.
    step_count_per_record: int = 1
50
+
51
+
52
class BaseWriterWithAD:
    """Base metric writer that runs anomaly detection (AD) on each scalar.

    For every tag an exponential moving average of the absolute value is kept
    in ``tag2scalars``; each new value is checked against that history with the
    configured ``ad_rules`` and detected anomalies are collected in
    ``anomalies``.
    """

    def __init__(self, writer_input: WriterInput):
        self.ad_rules = writer_input.ad_rules
        self.job_id = writer_input.job_id
        self.anomaly_factory = writer_input.anomaly_factory
        self.ndigits = writer_input.ndigits
        # Smoothing factor of the per-tag moving average.
        self.beta = 0.99
        # tag -> {'avg': moving average of |value|, 'count': number of updates}
        self.tag2scalars = {}
        self.anomalies = []

    @staticmethod
    def stack_tensors(tensor_list):
        """Bring all device tensors to the CPU with a single stacked transfer.

        Torch cannot stack CPU and device tensors together, so the non-CPU
        tensors are stacked on their own, moved to the CPU once, and then put
        back at their original positions. Every other item is passed through
        untouched.

        :param tensor_list: e.g. [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777),
            tensor(-1.8408, device='npu:0')]
        :return: list with the same length and order as ``tensor_list``; device
            tensors are replaced by the matching rows of the CPU stack
        """

        def lives_on_device(item):
            return isinstance(item, torch.Tensor) and item.device.type != 'cpu'

        # Stack on the device first, then move to the CPU in one go.
        device_items = [item for item in tensor_list if lives_on_device(item)]
        if device_items:
            stacked = torch.stack(device_items).cpu()
        else:
            stacked = torch.tensor([])

        # Restore the input order.
        restored = []
        device_pos = 0
        for item in tensor_list:
            if lives_on_device(item):
                restored.append(stacked[device_pos])
                device_pos += 1
            else:
                restored.append(item)
        return restored

    def get_anomalies(self):
        """Return the list of anomalies detected so far."""
        return self.anomalies

    def clear_anomalies(self):
        """Empty the list of detected anomalies in place."""
        self.anomalies.clear()

    def add_scalar(self, tag, scalar_value, global_step=None):
        """Run anomaly detection on one value and record any anomaly found.

        Args:
            tag (tuple): (tag_name, op) pair such as
                ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
            scalar_value (float | torch.Tensor): value to check; tensors are
                converted with ``.item()``.
            global_step (int): step reported in the anomaly message.
        Returns:
            None
        """
        if not self.ad_rules:
            return
        # Shape and dtype entries are not numeric metrics.
        if tag[-1] in ("shape", "dtype"):
            return

        if isinstance(scalar_value, torch.Tensor):
            scalar_value = scalar_value.item()
        history_avg = self._update_tag2scalars(tag, scalar_value)
        detected, rule_name = self._ad(scalar_value, history=history_avg)
        if not detected:
            return
        # Turbulence hits are only reported for the "norm" and "mean" ops.
        if rule_name == AnomalyTurbulence.name and tag[-1] not in ("norm", "mean"):
            return

        exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, "
                             f"current value {scalar_value}, history mean {history_avg}.")
        logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}")
        # Keep a record for a later dump when a factory is available.
        if self.anomaly_factory:
            record = self.anomaly_factory.create(tag, exception_message, global_step)
            self.anomalies.append(record)

    def write_metrics(self, ops, metric_value, step, prefix=''):
        """Feed every metric in ``metric_value`` to ``add_scalar``.

        Args:
            ops: op names; combined with the keys of ``metric_value`` to form
                the (name, op) tags.
            metric_value (dict): name -> {op: value} mapping.
            step (int): step passed on to ``add_scalar``.
            prefix (str): unused here; accepted for subclasses.
        """
        if not metric_value:
            return
        tags = list(itertools.product(metric_value.keys(), ops))
        tensors = [value for op2tensor in metric_value.values() for value in op2tensor.values()]
        if not tensors:
            return

        # Process the values in chunks of SLICE_SIZE so that each chunk needs
        # at most one device-to-CPU transfer.
        chunk = MonitorConst.SLICE_SIZE
        with torch.no_grad():
            for begin in range(0, len(tensors), chunk):
                end = begin + chunk
                metric_list = self.stack_tensors(tensors[begin:end])
                for tag, metric in zip(tags[begin:end], metric_list):
                    self.add_scalar(tag, metric, step)

    def _ad(self, scalar_value, history):
        """Scan the current value against its history with the configured rules."""
        return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)

    def _update_tag2scalars(self, tag, scalar_value):
        """Fold ``scalar_value`` into the moving average kept for ``tag``.

        The average is an exponential moving average (factor ``self.beta``) of
        the absolute values; a tag seen for the first time starts from the
        absolute value itself.

        Args:
            tag (tuple): The tag identifier.
            scalar_value (float): The scalar value to be added.

        Returns:
            float: The average value before this update.
        """
        magnitude = abs(scalar_value)
        record = self.tag2scalars.setdefault(tag, {'avg': magnitude, 'count': 0})
        previous_avg = record['avg']
        record['avg'] = self.beta * previous_avg + (1 - self.beta) * magnitude
        record['count'] += 1
        return previous_avg
173
+
174
+
175
class CSVWriterWithAD(BaseWriterWithAD):
    """Writer that runs anomaly detection and appends the metrics to CSV files.

    Values are buffered per name in ``context_dict`` by ``add_scalar`` and
    flushed to ``<prefix>_<first step>-<last step>.csv`` by ``write_csv``.
    """

    def __init__(self, writer_input: WriterInput):
        super().__init__(writer_input)

        self.log_dir = writer_input.path
        create_directory(self.log_dir)
        change_mode(self.log_dir, FileCheckConst.DATA_DIR_AUTHORITY)
        # name -> list of values collected since the last flush
        self.context_dict = defaultdict(list)
        self.header = []
        self.step_count_per_record = writer_input.step_count_per_record

    def get_step_interval(self, step):
        """Return the (first, last) step of the record window containing ``step``."""
        window = self.step_count_per_record
        index = step // window
        first_step = index * window
        last_step = (index + 1) * window - 1
        return first_step, last_step

    def write_csv(self, prefix, step):
        """Append the buffered metrics to the CSV file of the current window.

        Args:
            prefix[str]: prefix of output csv file e.g. grad_unreduced
            step[int]
        """
        if not self.context_dict:
            return

        step_start, step_end = self.get_step_interval(step)
        filepath = os.path.join(self.log_dir, f'{prefix}_{step_start}-{step_end}.csv')
        if not os.path.exists(filepath):
            # First write for this window: create the file with the header row only.
            write_df_to_csv(pd.DataFrame(columns=self.header), filepath)

        rows = []
        for name, metric_value in self.context_dict.items():
            cells = name.split(MonitorConst.NAME_SEP) + metric_value
            # The step becomes the third column of the row.
            rows.append(cells[:2] + [step] + cells[2:])
        frame = pd.DataFrame(rows).round(self.ndigits).fillna("nan")
        write_df_to_csv(frame, filepath, mode='a+', header=False)
        # Start a fresh buffer for the next flush.
        self.context_dict = defaultdict(list)

    def add_scalar(self, tag, scalar_value, global_step):
        """Run anomaly detection, then buffer the value under the tag's name.

        ``tag`` looks like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min');
        the part of ``tag[0]`` before the first '/' is used as the buffer key.
        """
        super().add_scalar(tag, scalar_value, global_step)

        name = tag[0].partition('/')[0]
        if isinstance(scalar_value, torch.Tensor):
            value = scalar_value.item()
        elif isinstance(scalar_value, torch.Size):
            value = list(scalar_value)
        else:
            value = scalar_value
        self.context_dict[name].append(value)

    def write_metrics(self, ops, metric_value, step, prefix='', **kwargs):
        """Collect the metrics, choose the CSV header and flush to disk.

        The micro-step header is used for activation / activation-gradient
        prefixes or when ``use_micro_step`` is passed as a truthy keyword.
        """
        super().write_metrics(ops, metric_value, step, prefix='')

        micro_step = prefix in (MonitorConst.ACTV, MonitorConst.ACTVGRAD) or kwargs.get("use_micro_step")
        if micro_step:
            leading_columns = MonitorConst.CSV_HEADER_MICRO_STEP
        else:
            leading_columns = MonitorConst.CSV_HEADER
        self.header = leading_columns + ops
        self.write_csv(prefix, step)

    def close(self):
        """Nothing to release; present so the writer can be closed like a SummaryWriter."""
        pass
241
+
242
+
243
class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD):
    """TensorBoard ``SummaryWriter`` that also runs anomaly detection on each scalar."""

    def __init__(self, writer_input: WriterInput):
        log_dir = writer_input.path
        if not os.path.exists(log_dir):
            create_directory(log_dir)
        try:
            # Initialise the class that follows SummaryWriter in the MRO
            # (the anomaly-detection base) ...
            super(SummaryWriter, self).__init__(writer_input)
            # ... and then SummaryWriter itself.
            super().__init__(log_dir)
        except Exception as err:
            logger.error(f'error when init summary writer at {log_dir}: {err}')
            raise ValueError("Init summary writer error.") from err

    def add_scalar(self, tag, scalar_value, global_step):
        """Run anomaly detection on the value, then log it to TensorBoard.

        ``tag`` is a (name, op) pair; TensorBoard receives it as "<name>_<op>".
        """
        # Anomaly detection works on the tuple form of the tag.
        super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step)
        flat_tag = f'{tag[0]}_{tag[1]}'
        super().add_scalar(flat_tag, scalar_value, global_step)