mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,16 +12,12 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import itertools
- import math
  import re
- import statistics

  import torch

- from msprobe.core.common.const import MonitorConst
- from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean
- from msprobe.core.common.log import logger
+ from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+ from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE


  def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -31,7 +27,9 @@ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
      return f"{module_or_param_name}/rank{rank}/{tag}"


- def squash_param_name(param_name):
+ def squash_param_name(param_name, enable=True):
+     if not enable:
+         return param_name
      name = ''
      for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
          match = re.findall(pattern, param_name)
@@ -63,7 +61,7 @@ class TensorMetrics:
          self.metrics = {}  # tensor_tag --> []
          self.cur_idx = {}

-     def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8):
+     def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
          """get stats and insert into metrics dictionary"""
          prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
          for stat_op in stat_ops:
@@ -120,14 +118,14 @@ class NormMetric(Metric):
      @staticmethod
      def get_metric_value(tensor, eps):
          return get_norm(tensor)
-
+

  @register_config_metric("zeros")
  class ZerosMetric(Metric):
      @staticmethod
      def get_metric_value(tensor, eps):
          return get_zeros(tensor, eps)
-
+

  @register_config_metric("nans")
  class NaNsMetric(Metric):
@@ -146,48 +144,29 @@ class IdentMetric(Metric):


  def get_metrics(ops, tag2tensor, eps, out_dict=None):
+     """
+     :param ops: ["op1", "op2"]
+     :param tag2tensor: {
+         '0:fc_0/input': torch.randn([3, 4]),
+         '0:fc_0/output': torch.randn([3, 3])
+     }
+     :param eps: float 1e-8
+     :param out_dict: {
+         '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+         '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+     }
+     :return: out_dict
+     """
      if out_dict is None:
          out_dict = {}
      for tag, tensor in tag2tensor.items():
          if tag not in out_dict:
              out_dict[tag] = {}
-         for metric_name in ops:
+         if not torch.is_tensor(tensor):
+             # Non-tensor in/output filled with nan.
+             out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops})
+             continue
+         for metric_name in ops:
              fun_metric = config_metric_registry.get(metric_name)
              out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
      return out_dict
-
-
- def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
-     if not metric_value:
-         return
-     tensors = []
-     tags = list(itertools.product(metric_value.keys(), ops))
-     for op2tensor in metric_value.values():
-         tensors.extend(op2tensor.values())
-     with torch.no_grad():
-         metric_list = torch.stack(tensors).cpu()
-     for tag, metric in zip(tags, metric_list):
-         summary_writer.add_scalar(tag, metric, step)
-
-
- def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
-     write_metrics_base(ops, summary_writer, metric_value, step, prefix='')
-
-     if not summary_writer.header:
-         # Forward norms use input.ops_ and output.ops_; backward ones use input_grad.ops_ and output_grad.ops_
-         if prefix in {"actv", "actv_grad"}:
-             if prefix == "actv":
-                 input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
-             else:
-                 input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
-             ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
-             summary_writer.header = ["module_name", "step", *ops_]
-         else:
-             summary_writer.header = ["param_name", "step", *ops]
-
-     for key in metric_value.keys():
-         if MonitorConst.VPP_SEP in key:
-             summary_writer.header.insert(0, 'vpp_stage')
-             break
-     summary_writer.write_csv(prefix, step)
-     summary_writer.header = []
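The reworked `get_metrics` above now documents its contract and tolerates non-tensor inputs. A minimal usage sketch, assuming the module imports as `msprobe.pytorch.monitor.module_metric` (per the file list) and that `"norm"` and `"zeros"` are registered metric names, as this file registers them:

```python
import torch

from msprobe.pytorch.monitor.module_metric import get_metrics

tag2tensor = {
    '0:fc_0/input': torch.randn([3, 4]),  # real tensor: metrics are computed
    '0:fc_0/output': None,                # non-tensor: every op maps to NaN now
}

out = get_metrics(["norm", "zeros"], tag2tensor, 1e-8)
# out['0:fc_0/input']  -> {'norm': tensor(...), 'zeros': tensor(...)}
# out['0:fc_0/output'] -> {'norm': tensor(nan), 'zeros': tensor(nan)}
```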
@@ -17,7 +17,7 @@ import re
  import abc
  import torch

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger

  # Registry that stores every validator implementation class
  config_validator_registry = {}
@@ -79,6 +79,8 @@ class TupleValidator(ConfigValidator):

  def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
      focused_col = None
+     if not config_spec or not isinstance(config_spec, str):
+         return focused_col
      for _, validator_cls in config_validator_registry.items():
          config_validator = validator_cls()
          pattern_match = config_validator.check_pattern_match(config_spec)
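The new guard short-circuits `validate_config_spec` before it consults the validator registry. A quick sketch of the changed behavior (import path taken from the file list above):

```python
from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec

# An empty or non-string spec now returns None instead of scanning the registry.
assert validate_config_spec("", [], "fc_0", "input") is None
assert validate_config_spec(None, [], "fc_0", "input") is None
```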
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,13 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from abc import ABC, abstractmethod
  from collections import defaultdict

  import torch
  import torch.distributed as dist

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.monitor.utils import MVResult, MVGradResult


@@ -87,7 +86,7 @@ class OptimizerMon(object):
          partition_id = dist.get_rank()

          def get_flatten_grad(self, optimizer, group_idx):
-             if fp32_partitioned_groups_flat[group_idx].grad is None:
+             if fp32_partitioned_groups_flat[group_idx].grad is None:
                  if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
                      fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
                          optimizer.averaged_gradients[group_idx],
@@ -151,29 +150,36 @@ class MixPrecisionOptimizerMon(OptimizerMon):
      Mixed-precision training speeds up training and reduces memory consumption by appropriately lowering the precision of certain computations.
      """

+     def map_fp16_tp_fp32_param(self, mix_prec_opt):
+         for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = fp32_param
+
      def fetch_mv(self, monitor, torch_opt, params2name):
          mix_prec_opt = self.wrapped_optimizer

          if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
-                 for fp16_param, fp32_param in zip(fp16_group, fp32_group):
-                     self.fp16_to_fp32_param[fp16_param] = fp32_param
+             self.map_fp16_tp_fp32_param(mix_prec_opt)
+
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


  class MegatronDistributedOptimizerMon(OptimizerMon):
-     def fetch_mv(self, monitor, torch_opt, params2name):
-         mix_prec_opt = self.wrapped_optimizer
+     def map_fp16_tp_fp32_param(self, mix_prec_opt):
          if not (hasattr(mix_prec_opt, "model_float16_groups") and
                  hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
              raise Exception(
                  "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                  "if not, please check megatron-lm version")
+         for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
+                                                 mix_prec_opt.shard_fp32_from_float16_groups):
+             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
          if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
-                                                     mix_prec_opt.shard_fp32_from_float16_groups):
-                 for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
-                     self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+             self.map_fp16_tp_fp32_param(mix_prec_opt)

          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)

@@ -183,6 +189,36 @@ class MegatronFP32OptimizerMon(OptimizerMon):
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


+ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             for opt in mix_prec_opt.chained_optimizers:
+                 self.map_fp16_tp_fp32_param(opt)
+
+         if not isinstance(torch_opt, torch.optim.Optimizer):
+             torch_opt.state = {}
+             for opt in mix_prec_opt.chained_optimizers:
+                 torch_opt.state.update(opt.optimizer.state)
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             for opt in mix_prec_opt.chained_optimizers:
+                 self.map_fp16_tp_fp32_param(opt)
+
+         if not isinstance(torch_opt, torch.optim.Optimizer):
+             torch_opt.state = {}
+             for opt in mix_prec_opt.chained_optimizers:
+                 torch_opt.state.update(opt.optimizer.state)
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
  class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
      def fetch_mv(self, monitor, torch_opt, params2name):
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
@@ -271,13 +307,15 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):

  class DummyOptimizerMon(OptimizerMon):
      def fetch_mv(self, monitor, torch_opt, params2name):
-         return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None)
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


  class OptimizerMonFactory:
      _optimizer_mon_map = {
          "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
          "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
+         "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
          "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
          "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
          "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
@@ -1,11 +1,26 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
  import os
  import re
- import argparse
  from glob import glob

  import pandas as pd

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger


  def parse_logfile(logfile):
@@ -21,19 +36,19 @@ def parse_logfile(logfile):
  def parse_monitor_output(output_dir):
      reduced = {}
      unreduced = {}
-     for dir in glob(output_dir + '*'):
-         rank = int(re.findall('(?<=rank)[\d]*', dir)[0])
+     for directory in glob(output_dir + '*'):
+         rank = int(re.findall('(?<=rank)[\d]*', directory)[0])
          unreduced[rank] = []
          reduced[rank] = []
-         for file in os.listdir(dir):
-             df = pd.read_csv(os.path.join(dir, file))
+         for file in os.listdir(directory):
+             df = pd.read_csv(os.path.join(directory, file))
              if '_unreduced_' in file:
                  unreduced[rank].append(df)
                  pass
              elif '_reduced_' in file:
                  reduced[rank].append(df)
              else:
-                 logger.info(f'unexpected file {file} in {dir}')
+                 logger.info(f'unexpected file {file} in {directory}')
      return reduced, unreduced


@@ -41,7 +56,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
      steps = len(reduced[0])
      world_size = len(reduced)
      errors = []
-     for index, row in unreduced[0][0].iterrows():
+     for _, row in unreduced[0][0].iterrows():
          param = row['param_name']
          is_tp_duplicate = False
          for step in range(2):
@@ -103,7 +118,7 @@ def valid_total_norm(total_norm, reduced, duplicate_embedding):
              if step == 0:
                  logger.info(f'rank {rank} is duplicated in dp group')
                  continue
-             for index, row in reduced[rank][step].iterrows():
+             for _, row in reduced[rank][step].iterrows():
                  if duplicate_embedding and 'word_embedding' in row['param_name']:
                      continue
                  calculated_norm += row['norm'] ** 2
@@ -16,13 +16,27 @@ import inspect
  from collections import namedtuple
  from datetime import timezone, timedelta
  from functools import wraps
+ from datetime import datetime
+ import os
+ import re

  import torch

  from msprobe.core.common.const import MonitorConst, Const
- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger
  from msprobe.core.common.utils import is_int
+ from msprobe.core.common.file_utils import check_file_or_directory_path

+
+ device = "cpu"
+ try:
+     import torch_npu
+     device = "npu"
+ except ImportError:
+     if torch.cuda.is_available():
+         device = "cuda"
+
+ NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
  FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
  FILE_NAME_MAX_LENGTH = 255
  DIRECTORY_MAX_LENGTH = 4096
@@ -39,6 +53,10 @@ class MsgConst:
      SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]


+ def get_output_base_dir():
+     return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
  def filter_special_chars(func):
      @wraps(func)
      def func_level(msg):
@@ -109,15 +127,16 @@ def is_recomputation():
  def validate_ops(ops):
      if not isinstance(ops, list):
          raise TypeError("ops should be a list")
-     if not ops:
-         raise TypeError(f"specify ops to calculate metrics. Optional ops: {MonitorConst.OP_LIST}")
-
      valid_ops = []
      for op in ops:
          if op not in MonitorConst.OP_LIST:
              logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
-         else:
-             valid_ops.append(op)
+             continue
+         valid_ops.append(op)
+     if not valid_ops:
+         default_op = MonitorConst.OP_LIST[0]
+         valid_ops.append(default_op)
+         logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
      return valid_ops


@@ -164,6 +183,11 @@ def validate_mg_distribution(mg_distribution):
          raise TypeError('mg_distribution should be a bool')


+ def validate_param_distribution(param_distribution):
+     if not isinstance(param_distribution, bool):
+         raise TypeError('param_distribution should be a bool')
+
+
  def validate_cc_distribution(cc_distribution):
      if not isinstance(cc_distribution, dict):
          raise TypeError('cc_distribution should be a dictionary')
@@ -184,6 +208,11 @@ def validate_cc_distribution(cc_distribution):
          raise TypeError(f'{key} of cc_distribution is not supported.')


+ def validate_squash_name(squash_name):
+     if not isinstance(squash_name, bool):
+         raise TypeError('squash_name should be a bool')
+
+
  def validate_alert(alert):
      if not isinstance(alert, dict):
          raise TypeError('alert should be a dictionary')
@@ -240,6 +269,9 @@ def validate_config(config):
      mg_distribution = config.get('mg_distribution', False)
      validate_mg_distribution(mg_distribution)

+     param_distribution = config.get('param_distribution', False)
+     validate_param_distribution(param_distribution)
+
      cc_distribution = config.get('cc_distribution', {})
      validate_cc_distribution(cc_distribution)

@@ -248,3 +280,42 @@ def validate_config(config):

      step_count_per_record = config.get('step_count_per_record', 1)
      validate_step_count_per_record(step_count_per_record)
+
+     squash_name = config.get('squash_name', True)
+     validate_squash_name(squash_name)
+
+     if not targets:
+         if xy_distribution:
+             config["all_xy"] = True
+         config["targets"] = {"": {}}
+
+
+ def time_str2time_digit(time_str):
+     time_format = '%b%d_%H-%M-%S'
+     try:
+         time_digit = datetime.strptime(time_str, time_format)
+     except Exception as e:
+         raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+             of existing output dirpath, like 'Dec03_21-34-40'.") from e
+     return time_digit
+
+
+ def get_target_output_dir(monitor_path, time_start, time_end):
+     check_file_or_directory_path(monitor_path, isdir=True)
+     time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+     time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+     if time_start and time_end and time_start > time_end:
+         raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+     result = {}
+     for dirname in os.listdir(monitor_path):
+         match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+         if not match:
+             continue
+         time_tag = match.group(1)
+         rank = match.group(2)
+         target_time = time_str2time_digit(time_tag)
+         start_ok = time_start is None or target_time >= time_start
+         end_ok = time_end is None or target_time <= time_end
+         if start_ok and end_ok:
+             result[rank] = os.path.join(monitor_path, dirname)
+     return result
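The two new helpers select monitor output directories by their time tag. A hedged usage sketch, assuming directory names carry a `'%b%d_%H-%M-%S'` tag and a rank id matching `MonitorConst.OUTPUT_DIR_PATTERN` (e.g. a name ending in `Dec03_21-34-40-rank0`):

```python
from msprobe.pytorch.monitor.utils import get_target_output_dir

# Keep only the per-rank dirs whose time tag falls inside the window;
# either bound may be None to leave that side open.
dirs_by_rank = get_target_output_dir("./monitor_output", "Dec03_21-34-40", None)
for rank, path in sorted(dirs_by_rank.items()):
    print(rank, path)
```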
@@ -56,7 +56,7 @@ class PtdbgDispatch(TorchDispatchMode):

          self.device_id = torch_npu._C._npu_getDevice()
          self.dump_mode = dump_mode
-         self.dump_api_list = api_list
+         self.dump_api_list = api_list or []
          self.debug_flag = debug
          self.api_index = 0
          self.single_api_index_dict = {}
@@ -182,7 +182,13 @@ class PtdbgDispatch(TorchDispatchMode):
              npu_out_cpu = safe_get_value(npu_out_cpu, 0, "npu_out_cpu")

          with TimeStatistics("CPU RUN", run_param):
-             cpu_out = func(*cpu_args, **cpu_kwargs)
+             try:
+                 cpu_out = func(*cpu_args, **cpu_kwargs)
+             except RuntimeError as e:
+                 self.api_index -= 1
+                 logger.warning(f"RuntimeError: {e}")
+                 logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
+                 return npu_out

          if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
              cpu_out = cpu_out.float()
@@ -1,8 +1,7 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
@@ -13,16 +12,17 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """

  import os
  import time
- import numpy as np
  from collections import namedtuple
- from msprobe.pytorch.parse_tool.lib.utils import Util
+
+ import numpy as np
+
+ from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
  from msprobe.pytorch.parse_tool.lib.config import Const
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
- from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
+ from msprobe.pytorch.parse_tool.lib.utils import Util


  class Compare:
@@ -126,7 +126,7 @@ class Compare:
          all_close = np.allclose(data_left, data_right, atol=al, rtol=rl)
          np.seterr(divide='raise')
          cos_sim = np.dot(data_left, data_right) / (
-             np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right)))
+                 np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right)))
          err_cnt = 0
          total_cnt = data_left.shape[0]
          diff_table_columns = ['Index', 'Left', 'Right', 'Diff']
@@ -1,8 +1,7 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
@@ -13,14 +12,13 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """

  import os
+
  import numpy as np


  class Const:
-
      MS_ACCU_CMP_PATH = '/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py'
      MS_ACCU_CMP_FILE_NAME = 'msaccucmp.py'
      ROOT_DIR = ""
@@ -1,4 +1,18 @@
- # coding=utf-8
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import os


@@ -1,8 +1,7 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
@@ -13,13 +12,14 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """
- import cmd
+
  import argparse
- from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool
- from msprobe.pytorch.parse_tool.lib.utils import Util
+ import cmd
+
  from msprobe.pytorch.parse_tool.lib.config import Const
  from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception
+ from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool
+ from msprobe.pytorch.parse_tool.lib.utils import Util


  class InteractiveCli(cmd.Cmd):
@@ -81,7 +81,7 @@ class InteractiveCli(cmd.Cmd):
          self.util.check_files_in_path(args.my_dump_path)
          self.util.check_files_in_path(args.golden_dump_path)
          if self.util.dir_contains_only(args.my_dump_path, ".npy") and \
-             self.util.dir_contains_only(args.golden_dump_path, ".npy"):
+                 self.util.dir_contains_only(args.golden_dump_path, ".npy"):
              self.parse_tool.do_compare_converted_dir(args)
          else:
              self.parse_tool.do_vector_compare(args)
@@ -1,8 +1,7 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
@@ -13,13 +12,13 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """
+
  import logging
+
  from msprobe.core.common.exceptions import FileCheckException


  class ParseException(Exception):
-
      PARSE_INVALID_PATH_ERROR = 0
      PARSE_NO_FILE_ERROR = 1
      PARSE_NO_MODULE_ERROR = 2
@@ -51,4 +50,5 @@ def catch_exception(func):
      except FileCheckException:
          log.error("Command execution failed")
      return result
+
      return inner