mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff covers the contents of publicly released package versions as they appear in their supported public registries. It is provided for informational purposes only and reflects the changes between those versions.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/monitor/module_metric.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,16 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
-import math
 import re
-import statistics
 
 import torch
 
-from msprobe.
-from msprobe.pytorch.monitor.
-from msprobe.core.common.log import logger
+from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE
 
 
 def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -31,7 +27,9 @@ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
     return f"{module_or_param_name}/rank{rank}/{tag}"
 
 
-def squash_param_name(param_name):
+def squash_param_name(param_name, enable=True):
+    if not enable:
+        return param_name
     name = ''
     for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
         match = re.findall(pattern, param_name)
@@ -63,7 +61,7 @@ class TensorMetrics:
         self.metrics = {}  # tensor_tag --> []
         self.cur_idx = {}
 
-    def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank
+    def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
         """get stats and insert into metrics dictionary"""
         prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
         for stat_op in stat_ops:
@@ -120,14 +118,14 @@ class NormMetric(Metric):
     @staticmethod
     def get_metric_value(tensor, eps):
         return get_norm(tensor)
-
+
 
 @register_config_metric("zeros")
 class ZerosMetric(Metric):
     @staticmethod
     def get_metric_value(tensor, eps):
         return get_zeros(tensor, eps)
-
+
 
 @register_config_metric("nans")
 class NaNsMetric(Metric):
@@ -146,48 +144,29 @@ class IdentMetric(Metric):
 
 
 def get_metrics(ops, tag2tensor, eps, out_dict=None):
+    """
+    :param ops: ["op1", "op2"]
+    :param tag2tensor: {
+        '0:fc_0/input': torch.randn([3, 4]),
+        '0:fc_0/output': torch.randn([3, 3])
+        }
+    :param eps: float 1e-8
+    :param out_dict: {
+        '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+        '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+        }
+    :return: out_dict
+    """
     if out_dict is None:
         out_dict = {}
     for tag, tensor in tag2tensor.items():
         if tag not in out_dict:
             out_dict[tag] = {}
-
+        if not torch.is_tensor(tensor):
+            # Non-tensor in/output filled with nan.
+            out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops})
+            continue
+        for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
     return out_dict
-
-
-def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
-    if not metric_value:
-        return
-    tensors = []
-    tags = list(itertools.product(metric_value.keys(), ops))
-    for op2tensor in metric_value.values():
-        tensors.extend(op2tensor.values())
-    with torch.no_grad():
-        metric_list = torch.stack(tensors).cpu()
-    for tag, metric in zip(tags, metric_list):
-        summary_writer.add_scalar(tag, metric, step)
-
-
-def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
-    write_metrics_base(ops, summary_writer, metric_value, step, prefix='')
-
-    if not summary_writer.header:
-        # Forward norms use input.ops_ and output.ops_; backward ones use input_grad.ops_ and output_grad.ops_
-        if prefix in {"actv", "actv_grad"}:
-            if prefix == "actv":
-                input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
-            else:
-                input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
-            ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
-            summary_writer.header = ["module_name", "step", *ops_]
-        else:
-            summary_writer.header = ["param_name", "step", *ops]
-
-    for key in metric_value.keys():
-        if MonitorConst.VPP_SEP in key:
-            summary_writer.header.insert(0, 'vpp_stage')
-            break
-    summary_writer.write_csv(prefix, step)
-    summary_writer.header = []
```
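
The reworked `get_metrics` now documents its contract and fills non-tensor inputs/outputs with a NaN tensor instead of failing. Below is a minimal, self-contained sketch of that behavior; the two-entry metric registry and the plain CPU `NAN_TENSOR` are stand-ins for msprobe's `config_metric_registry` and `NAN_TENSOR_ON_DEVICE`, not the library's actual objects.

```python
# Sketch of the non-tensor fallback added to get_metrics() (illustrative, not the msprobe API).
import torch

NAN_TENSOR = torch.tensor(torch.nan)  # msprobe builds this on npu/cuda when available

metric_registry = {
    "min": lambda t, eps: t.min(),
    "norm": lambda t, eps: t.norm(),
}

def get_metrics_sketch(ops, tag2tensor, eps, out_dict=None):
    out_dict = out_dict if out_dict is not None else {}
    for tag, tensor in tag2tensor.items():
        out_dict.setdefault(tag, {})
        if not torch.is_tensor(tensor):
            # Non-tensor in/output filled with nan, mirroring the new branch above.
            out_dict[tag].update({name: NAN_TENSOR for name in ops})
            continue
        for name in ops:
            out_dict[tag][name] = metric_registry[name](tensor, eps)
    return out_dict

print(get_metrics_sketch(["min", "norm"],
                         {"0:fc_0/input": torch.randn(3, 4), "0:fc_0/mask": None}, 1e-8))
```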
msprobe/pytorch/monitor/module_spec_verifier.py CHANGED

```diff
@@ -17,7 +17,7 @@ import re
 import abc
 import torch
 
-from msprobe.
+from msprobe.pytorch.common.log import logger
 
 # Registry that stores all validator implementation classes
 config_validator_registry = {}
@@ -79,6 +79,8 @@ class TupleValidator(ConfigValidator):
 
 def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
     focused_col = None
+    if not config_spec or not isinstance(config_spec, str):
+        return focused_col
     for _, validator_cls in config_validator_registry.items():
         config_validator = validator_cls()
         pattern_match = config_validator.check_pattern_match(config_spec)
```
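
The two added lines are a guard: an empty or non-string `config_spec` now returns `None` immediately instead of being run through every registered validator. A small sketch of the pattern, with a made-up registry entry standing in for the registered `ConfigValidator` classes:

```python
# Illustrative guard-clause sketch; the registry entry below is hypothetical.
config_validator_registry = {"tuple": lambda spec: spec.startswith("tuple")}

def validate_config_spec(config_spec, actual_data, module_name, data_type):
    focused_col = None
    if not config_spec or not isinstance(config_spec, str):
        return focused_col  # skip validator lookup entirely
    for _, matches in config_validator_registry.items():
        if matches(config_spec):
            focused_col = 0  # placeholder for the real validator result
    return focused_col

assert validate_config_spec(None, None, "fc_0", "input") is None
assert validate_config_spec("tuple[0]", None, "fc_0", "input") == 0
```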
msprobe/pytorch/monitor/optimizer_collect.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import ABC, abstractmethod
 from collections import defaultdict
 
 import torch
 import torch.distributed as dist
 
-from msprobe.
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
 
 
@@ -87,7 +86,7 @@ class OptimizerMon(object):
         partition_id = dist.get_rank()
 
     def get_flatten_grad(self, optimizer, group_idx):
-        if
+        if fp32_partitioned_groups_flat[group_idx].grad is None:
             if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
                 fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
                     optimizer.averaged_gradients[group_idx],
@@ -151,29 +150,36 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     Mixed-precision training speeds up training and reduces memory use by lowering the precision of selected computations.
     """
 
+    def map_fp16_tp_fp32_param(self, mix_prec_opt):
+        for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                self.fp16_to_fp32_param[fp16_param] = fp32_param
+
     def fetch_mv(self, monitor, torch_opt, params2name):
         mix_prec_opt = self.wrapped_optimizer
 
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-
-
-                    self.fp16_to_fp32_param[fp16_param] = fp32_param
+            self.map_fp16_tp_fp32_param(mix_prec_opt)
+
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def
-        mix_prec_opt = self.wrapped_optimizer
+    def map_fp16_tp_fp32_param(self, mix_prec_opt):
         if not (hasattr(mix_prec_opt, "model_float16_groups") and
                 hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                 "if not, please check megatron-lm version")
+        for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
+                                                mix_prec_opt.shard_fp32_from_float16_groups):
+            for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        mix_prec_opt = self.wrapped_optimizer
         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-
-                                                    mix_prec_opt.shard_fp32_from_float16_groups):
-                for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
-                    self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+            self.map_fp16_tp_fp32_param(mix_prec_opt)
 
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
@@ -183,6 +189,36 @@ class MegatronFP32OptimizerMon(OptimizerMon):
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
+class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        mix_prec_opt = self.wrapped_optimizer
+
+        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+            for opt in mix_prec_opt.chained_optimizers:
+                self.map_fp16_tp_fp32_param(opt)
+
+        if not isinstance(torch_opt, torch.optim.Optimizer):
+            torch_opt.state = {}
+            for opt in mix_prec_opt.chained_optimizers:
+                torch_opt.state.update(opt.optimizer.state)
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        mix_prec_opt = self.wrapped_optimizer
+
+        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+            for opt in mix_prec_opt.chained_optimizers:
+                self.map_fp16_tp_fp32_param(opt)
+
+        if not isinstance(torch_opt, torch.optim.Optimizer):
+            torch_opt.state = {}
+            for opt in mix_prec_opt.chained_optimizers:
+                torch_opt.state.update(opt.optimizer.state)
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
 class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
@@ -271,13 +307,15 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
 
 class DummyOptimizerMon(OptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-        return
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
         "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
         "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
         "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
         "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
```
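
The two new `MegatronChained*` monitor classes handle optimizer wrappers that expose a `chained_optimizers` list: the fp16-to-fp32 parameter map is built per member, and the members' optimizer state is merged into one dict so the existing Adam-based collection can run once. A standalone sketch of that merge, where `ChainedOptimizer` is a hypothetical stand-in for the Megatron wrapper and each member is a plain `torch.optim.Adam` rather than a wrapped Megatron optimizer:

```python
# Sketch of the state-merging idea behind the new chained-optimizer monitors
# (ChainedOptimizer here is a hypothetical stand-in, not the Megatron class).
import torch

class ChainedOptimizer:
    def __init__(self, optimizers):
        self.chained_optimizers = optimizers

params = [torch.nn.Parameter(torch.randn(2)) for _ in range(2)]
members = [torch.optim.Adam([p]) for p in params]
for p, opt in zip(params, members):
    p.grad = torch.ones_like(p)
    opt.step()  # populate per-parameter Adam state (exp_avg, exp_avg_sq, step)

chained = ChainedOptimizer(members)

# Mirror of the new fetch_mv(): expose a single combined state mapping.
merged_state = {}
for opt in chained.chained_optimizers:
    merged_state.update(opt.state)

print(len(merged_state))  # -> 2, one entry per tracked parameter
```

In the diff each chained member wraps an inner torch optimizer (`opt.optimizer.state`); the sketch flattens that one level for brevity.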
msprobe/pytorch/monitor/unittest/test_monitor.py CHANGED

```diff
@@ -1,11 +1,26 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
 import os
 import re
-import argparse
 from glob import glob
 
 import pandas as pd
 
-from msprobe.
+from msprobe.pytorch.common.log import logger
 
 
 def parse_logfile(logfile):
@@ -21,19 +36,19 @@ def parse_logfile(logfile):
 def parse_monitor_output(output_dir):
     reduced = {}
     unreduced = {}
-    for
-        rank = int(re.findall('(?<=rank)[\d]*',
+    for directory in glob(output_dir + '*'):
+        rank = int(re.findall('(?<=rank)[\d]*', directory)[0])
         unreduced[rank] = []
         reduced[rank] = []
-        for file in os.listdir(
-            df = pd.read_csv(os.path.join(
+        for file in os.listdir(directory):
+            df = pd.read_csv(os.path.join(directory, file))
             if '_unreduced_' in file:
                 unreduced[rank].append(df)
                 pass
             elif '_reduced_' in file:
                 reduced[rank].append(df)
             else:
-                logger.info(f'unexpected file {file} in {
+                logger.info(f'unexpected file {file} in {directory}')
     return reduced, unreduced
 
 
@@ -41,7 +56,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
     steps = len(reduced[0])
     world_size = len(reduced)
     errors = []
-    for
+    for _, row in unreduced[0][0].iterrows():
         param = row['param_name']
         is_tp_duplicate = False
         for step in range(2):
@@ -103,7 +118,7 @@ def valid_total_norm(total_norm, reduced, duplicate_embedding):
             if step == 0:
                 logger.info(f'rank {rank} is duplicated in dp group')
                 continue
-            for
+            for _, row in reduced[rank][step].iterrows():
                 if duplicate_embedding and 'word_embedding' in row['param_name']:
                     continue
                 calculated_norm += row['norm'] ** 2
```
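
The test fix spells out how per-rank output directories are discovered: each directory name carries a `rank<N>` suffix that is pulled out with a lookbehind regex. A tiny sketch of that parsing step; the directory names below are hypothetical examples, not paths produced by the tool:

```python
# Rank extraction as done in parse_monitor_output(); directory names are made up.
import re

for directory in ["monitor_output/Dec03_21-34-40-rank0", "monitor_output/Dec03_21-34-40-rank12"]:
    rank = int(re.findall(r'(?<=rank)[\d]*', directory)[0])
    print(directory, "->", rank)  # 0 and 12
```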
msprobe/pytorch/monitor/utils.py CHANGED

```diff
@@ -16,13 +16,27 @@ import inspect
 from collections import namedtuple
 from datetime import timezone, timedelta
 from functools import wraps
+from datetime import datetime
+import os
+import re
 
 import torch
 
 from msprobe.core.common.const import MonitorConst, Const
-from msprobe.
+from msprobe.pytorch.common.log import logger
 from msprobe.core.common.utils import is_int
+from msprobe.core.common.file_utils import check_file_or_directory_path
 
+
+device = "cpu"
+try:
+    import torch_npu
+    device = "npu"
+except ImportError:
+    if torch.cuda.is_available():
+        device = "cuda"
+
+NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
 FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
 FILE_NAME_MAX_LENGTH = 255
 DIRECTORY_MAX_LENGTH = 4096
@@ -39,6 +53,10 @@ class MsgConst:
     SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
 
 
+def get_output_base_dir():
+    return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
 def filter_special_chars(func):
     @wraps(func)
     def func_level(msg):
@@ -109,15 +127,16 @@ def is_recomputation():
 def validate_ops(ops):
     if not isinstance(ops, list):
         raise TypeError("ops should be a list")
-    if not ops:
-        raise TypeError(f"specify ops to calculate metrics. Optional ops: {MonitorConst.OP_LIST}")
-
     valid_ops = []
     for op in ops:
         if op not in MonitorConst.OP_LIST:
             logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
-
-
+            continue
+        valid_ops.append(op)
+    if not valid_ops:
+        default_op = MonitorConst.OP_LIST[0]
+        valid_ops.append(default_op)
+        logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
     return valid_ops
 
 
@@ -164,6 +183,11 @@ def validate_mg_distribution(mg_distribution):
     raise TypeError('mg_distribution should be a bool')
 
 
+def validate_param_distribution(param_distribution):
+    if not isinstance(param_distribution, bool):
+        raise TypeError('param_distribution should be a bool')
+
+
 def validate_cc_distribution(cc_distribution):
     if not isinstance(cc_distribution, dict):
         raise TypeError('cc_distribution should be a dictionary')
@@ -184,6 +208,11 @@ def validate_cc_distribution(cc_distribution):
         raise TypeError(f'{key} of cc_distribution is not supported.')
 
 
+def validate_squash_name(squash_name):
+    if not isinstance(squash_name, bool):
+        raise TypeError('squash_name should be a bool')
+
+
 def validate_alert(alert):
     if not isinstance(alert, dict):
         raise TypeError('alert should be a dictionary')
@@ -240,6 +269,9 @@ def validate_config(config):
     mg_distribution = config.get('mg_distribution', False)
     validate_mg_distribution(mg_distribution)
 
+    param_distribution = config.get('param_distribution', False)
+    validate_param_distribution(param_distribution)
+
     cc_distribution = config.get('cc_distribution', {})
     validate_cc_distribution(cc_distribution)
 
@@ -248,3 +280,42 @@ def validate_config(config):
 
     step_count_per_record = config.get('step_count_per_record', 1)
     validate_step_count_per_record(step_count_per_record)
+
+    squash_name = config.get('squash_name', True)
+    validate_squash_name(squash_name)
+
+    if not targets:
+        if xy_distribution:
+            config["all_xy"] = True
+        config["targets"] = {"": {}}
+
+
+def time_str2time_digit(time_str):
+    time_format = '%b%d_%H-%M-%S'
+    try:
+        time_digit = datetime.strptime(time_str, time_format)
+    except Exception as e:
+        raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+            of existing output dirpath, like 'Dec03_21-34-40'.") from e
+    return time_digit
+
+
+def get_target_output_dir(monitor_path, time_start, time_end):
+    check_file_or_directory_path(monitor_path, isdir=True)
+    time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+    time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+    if time_start and time_end and time_start > time_end:
+        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+    result = {}
+    for dirname in os.listdir(monitor_path):
+        match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+        if not match:
+            continue
+        time_tag = match.group(1)
+        rank = match.group(2)
+        target_time = time_str2time_digit(time_tag)
+        start_ok = time_start is None or target_time >= time_start
+        end_ok = time_end is None or target_time <= time_end
+        if start_ok and end_ok:
+            result[rank] = os.path.join(monitor_path, dirname)
+    return result
```
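
`get_target_output_dir` filters monitor output directories by the timestamp embedded in their names, using the `'%b%d_%H-%M-%S'` format shown in the error message (`Dec03_21-34-40`). The sketch below reproduces that filtering on plain directory-name strings; the regex is a hypothetical stand-in for `MonitorConst.OUTPUT_DIR_PATTERN`, whose real definition is not part of this diff.

```python
# Timestamp-window filtering in the spirit of get_target_output_dir().
import re
from datetime import datetime

OUTPUT_DIR_PATTERN = r"^([A-Z][a-z]{2}\d{2}_\d{2}-\d{2}-\d{2})-rank(\d+)$"  # assumed name layout
TIME_FORMAT = '%b%d_%H-%M-%S'  # e.g. 'Dec03_21-34-40'

def time_str2time_digit(time_str):
    return datetime.strptime(time_str, TIME_FORMAT)

def filter_output_dirs(dirnames, time_start=None, time_end=None):
    start = time_str2time_digit(time_start) if time_start else None
    end = time_str2time_digit(time_end) if time_end else None
    result = {}
    for dirname in dirnames:
        match = re.match(OUTPUT_DIR_PATTERN, dirname)
        if not match:
            continue
        tag_time = time_str2time_digit(match.group(1))
        if (start is None or tag_time >= start) and (end is None or tag_time <= end):
            result[match.group(2)] = dirname  # rank -> directory
    return result

dirs = ["Dec03_21-34-40-rank0", "Dec03_22-00-00-rank1", "notes.txt"]
print(filter_output_dirs(dirs, time_start="Dec03_21-50-00"))  # {'1': 'Dec03_22-00-00-rank1'}
```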
msprobe/pytorch/online_dispatch/dispatch.py CHANGED

```diff
@@ -56,7 +56,7 @@ class PtdbgDispatch(TorchDispatchMode):
 
         self.device_id = torch_npu._C._npu_getDevice()
         self.dump_mode = dump_mode
-        self.dump_api_list = api_list
+        self.dump_api_list = api_list or []
         self.debug_flag = debug
         self.api_index = 0
         self.single_api_index_dict = {}
@@ -182,7 +182,13 @@ class PtdbgDispatch(TorchDispatchMode):
         npu_out_cpu = safe_get_value(npu_out_cpu, 0, "npu_out_cpu")
 
         with TimeStatistics("CPU RUN", run_param):
-
+            try:
+                cpu_out = func(*cpu_args, **cpu_kwargs)
+            except RuntimeError as e:
+                self.api_index -= 1
+                logger.warning(f"RuntimeError: {e}")
+                logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
+                return npu_out
 
         if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
             cpu_out = cpu_out.float()
```
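
The change wraps the CPU reference run in a `try/except RuntimeError`, so an aten op with no CPU kernel is skipped and the device output is kept rather than aborting the whole dispatch. A sketch of that fallback pattern outside of `PtdbgDispatch`; the failing callable is a stand-in for an unsupported op:

```python
# CPU-fallback guard in the spirit of the new dispatch.py logic (illustrative only).
import torch

def run_reference_on_cpu(func, device_out, *cpu_args, **cpu_kwargs):
    try:
        return func(*cpu_args, **cpu_kwargs)
    except RuntimeError as e:
        print(f"RuntimeError: {e}")
        print("op does not support running on cpu, skip it")
        return device_out  # keep the device result when no CPU reference exists

def fake_device_only_op(x):  # hypothetical op without a CPU kernel
    raise RuntimeError("no CPU kernel registered")

out = run_reference_on_cpu(fake_device_only_op, torch.ones(2), torch.ones(2))
print(out)
```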
msprobe/pytorch/parse_tool/lib/compare.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,16 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
 import time
-import numpy as np
 from collections import namedtuple
-
+
+import numpy as np
+
+from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
 from msprobe.pytorch.parse_tool.lib.config import Const
 from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
-from msprobe.
+from msprobe.pytorch.parse_tool.lib.utils import Util
 
 
 class Compare:
@@ -126,7 +126,7 @@ class Compare:
         all_close = np.allclose(data_left, data_right, atol=al, rtol=rl)
         np.seterr(divide='raise')
         cos_sim = np.dot(data_left, data_right) / (
-
+                np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right)))
         err_cnt = 0
         total_cnt = data_left.shape[0]
         diff_table_columns = ['Index', 'Left', 'Right', 'Diff']
```
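
The completed line in `compare.py` is the standard cosine-similarity formula, cos(a, b) = a·b / (‖a‖·‖b‖), written with explicit dot products. Shown standalone for reference:

```python
# Cosine similarity exactly as composed in the diff, on sample data.
import numpy as np

data_left = np.array([1.0, 2.0, 3.0])
data_right = np.array([1.0, 2.0, 3.001])

cos_sim = np.dot(data_left, data_right) / (
        np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right)))
print(cos_sim)  # close to 1.0 for nearly identical vectors
```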
msprobe/pytorch/parse_tool/lib/config.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,14 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
+
 import numpy as np
 
 
 class Const:
-
     MS_ACCU_CMP_PATH = '/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py'
     MS_ACCU_CMP_FILE_NAME = 'msaccucmp.py'
     ROOT_DIR = ""
```
msprobe/pytorch/parse_tool/lib/file_desc.py CHANGED

```diff
@@ -1,4 +1,18 @@
-#
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 
```
msprobe/pytorch/parse_tool/lib/interactive_cli.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import cmd
+
 import argparse
-
-
+import cmd
+
 from msprobe.pytorch.parse_tool.lib.config import Const
 from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception
+from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool
+from msprobe.pytorch.parse_tool.lib.utils import Util
 
 
 class InteractiveCli(cmd.Cmd):
@@ -81,7 +81,7 @@ class InteractiveCli(cmd.Cmd):
         self.util.check_files_in_path(args.my_dump_path)
         self.util.check_files_in_path(args.golden_dump_path)
         if self.util.dir_contains_only(args.my_dump_path, ".npy") and \
-
+                self.util.dir_contains_only(args.golden_dump_path, ".npy"):
             self.parse_tool.do_compare_converted_dir(args)
         else:
             self.parse_tool.do_vector_compare(args)
```
msprobe/pytorch/parse_tool/lib/parse_exception.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
 import logging
+
 from msprobe.core.common.exceptions import FileCheckException
 
 
 class ParseException(Exception):
-
     PARSE_INVALID_PATH_ERROR = 0
     PARSE_NO_FILE_ERROR = 1
     PARSE_NO_MODULE_ERROR = 2
@@ -51,4 +50,5 @@ def catch_exception(func):
         except FileCheckException:
             log.error("Command execution failed")
         return result
+
     return inner
```