mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/monitor/optimizer_collect.py
ADDED
@@ -0,0 +1,334 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import abstractmethod
+
+from mindspore import mint, ops
+
+from msprobe.mindspore.common.log import logger
+from msprobe.core.common.const import MonitorConst
+
+
+class OptimizerMon(object):
+    def __init__(self, optim) -> None:
+        self.fp16_to_fp32_param = {}
+        self.optim = optim
+        self.state = {}
+
+    def narrow_from_flatten(self, param, flatten_state):
+        return flatten_state
+
+    def get_state(self, optim):
+        if hasattr(optim, 'chained_optimizers'):
+            for opt in optim.chained_optimizers:
+                self._get_single_state(opt)
+        else:
+            self._get_single_state(optim)
+
+    def fetch_grad(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.optim)
+
+        grad_dict = {}
+        first_param = True
+        for param, name in params2name.items():
+            if monitor.duplicate_param.get(name, False):
+                continue
+            if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                continue
+            grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+            if param.numel() != element_in_cur_partition:
+                if first_param:
+                    grad = grad.flatten()[-element_in_cur_partition:]
+                else:  # supposed to be the last one
+                    grad = grad.flatten()[:element_in_cur_partition]
+            first_param = False
+            if grad is None:
+                continue
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        return grad_dict
+
+    def map_fp16_to_fp32_param(self, optim):
+        pass
+
+    def fetch_mv(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.optim)
+        if not self.state:
+            self.get_state(self.optim)
+
+        exp_avg_dict = {}
+        exp_avg_sq_dict = {}
+        update_dict = {}
+        ratio_dict = {}
+
+        if not self.state:
+            logger.warning('optimizer state can not accessed')
+            return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+        for lp_param, name in params2name.items():
+            if lp_param in self.fp16_to_fp32_param:
+                hp_param = self.fp16_to_fp32_param[lp_param]
+            else:
+                hp_param = lp_param
+
+            if hp_param in self.state:
+                state_param = self.state.get(hp_param, {})
+                exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
+                if monitor.mv_distribution:
+                    exp_avg_dict[name] = exp_avg
+                    exp_avg_sq_dict[name] = exp_avg_sq
+                if monitor.mg_direction:
+                    exp_avg_dict[name] = exp_avg
+                if monitor.ur_distribution:
+                    if len(self.optim.param_groups) > 1:
+                        logger.info(f"the length of optim.param_groups is {len(self.optim.param_groups)}.")
+                    if 'step' in state_param:
+                        step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
+                    elif 'step' in self.optim.param_groups[0]:
+                        step = self.optim.param_groups[0]['step']  # AdamW from mindspeed
+                    else:
+                        logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                        continue
+                    if exp_avg is None or exp_avg_sq is None:
+                        logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
+                        continue
+                    exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps'])
+                    ratio_dict[name] = exp_avg_hat / mint.sqrt(exp_avg_sq_hat)
+                    monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                    monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict
+
+    def _get_single_state(self, optim):
+        state = {}
+        if hasattr(optim, 'param_to_cpu_states_map'):
+            state = optim.param_to_cpu_states_map
+        elif hasattr(optim, 'state'):
+            state = optim.state
+        elif hasattr(optim, 'optimizer') and hasattr(optim.optimizer, 'state'):
+            state = optim.optimizer.state
+        self.state.update(state)
+
+
+class MixPrecisionOptimizerMon(OptimizerMon):
+    """
+    Monitor class for mixed-precision optimizers: watches and manages the optimizer during mixed-precision training.
+    Mixed-precision training speeds up training and reduces memory consumption by lowering the precision of selected computations.
+    """
+    def map_fp16_to_fp32_param(self, optim):
+        for fp16_group, fp32_group in zip(optim.float16_groups, optim.fp32_from_float16_groups):
+            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                self.fp16_to_fp32_param[fp16_param] = fp32_param
+
+
+class MegatronDistributedOptimizerMon(OptimizerMon):
+    def map_fp16_to_fp32_param(self, optim):
+        if not (hasattr(optim, "model_float16_groups") and
+                hasattr(optim, "shard_fp32_from_float16_groups")):
+            raise Exception(
+                "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
+                "if not, please check megatron-lm version")
+        for fp16_group, shard_fp32_group in zip(optim.model_float16_groups,
+                                                optim.shard_fp32_from_float16_groups):
+            for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+
+class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+    def map_fp16_to_fp32_param(self, optim):
+        for opt in optim.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
+
+
+class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+    def map_fp16_to_fp32_param(self, optim):
+        for opt in optim.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
+
+
+class DeepSpeedZeroOptimizerMon(OptimizerMon):
+    """
+    Base monitor class for DeepSpeed ZeRO optimizer.
+    ZeRO stage 0 no partition
+    ZeRO stage 1 partitions optimizer states across data parallel processes.
+    ZeRO stage 2 additionally partitions gradients.
+    ZeRO stage 3 additionally partitions parameters.
+
+    This class provides monitoring capabilities for ZeRO optimizers by:
+    - Handling gradient collection for different ZeRO stages
+    - Managing optimizer state access for monitoring
+    """
+    def __init__(self, optim):
+        super().__init__(optim)
+        self.stage = ''
+        self.bit16_groups = []
+        self.fp32_flat_groups = []
+        self.param2group = ()
+        self.param2index = []
+        self.group_offset = {}
+
+    @abstractmethod
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        raise NotImplementedError
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+        return hp_address is None
+
+    def get_position(self, lp_param, group_idx):
+        param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param))
+        return hp_address.start, hp_address.numel
+
+    def get_group_index(self):
+        param2group = {}
+        for group_idx, bit16_group in enumerate(self.bit16_groups):
+            for param in bit16_group:
+                param2group[param] = group_idx
+        return param2group
+
+    def get_param_index(self, lp_param, group_idx):
+        if not self.param2index:
+            for group in self.bit16_groups:
+                param2index = {}
+                for index, param in enumerate(group):
+                    param2index[param] = index
+                self.param2index.append(param2index)
+
+        return self.param2index[group_idx][lp_param]
+
+    def narrow_from_flatten(self, param, flatten_state):
+        if flatten_state is None:
+            return flatten_state
+        group_idx = self.param2group[param]
+        if self.param_not_in_partition(param, group_idx):
+            return None
+        start, numel = self.get_position(param, group_idx)
+        return flatten_state.narrow(0, start, numel)
+
+    def map_fp16_to_fp32_param(self, optim):
+        for group_idx, group in enumerate(self.bit16_groups):
+            for param in group:
+                self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+    def fetch_grad(self, monitor, params2name):
+        grad_dict = {}
+        for lp_param, name in params2name.items():
+            group_idx = self.param2group[lp_param]
+            param_id = self.get_param_index(lp_param, group_idx)
+            if self.param_not_in_partition(lp_param, group_idx):
+                continue
+            if self.stage == '1or2':
+                param_id = param_id - self.group_offset[group_idx] - 1
+            grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        return grad_dict
+
+
+class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, optim):
+        super().__init__(optim)
+        self.stage = '0'
+        self.bit16_groups = optim.bf16_groups
+        self.fp32_flat_groups = optim.fp32_groups_flat_partition
+        self.param2group = self.get_group_index()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.optim.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, optim):
+        super().__init__(optim)
+        self.stage = '1or2'
+        self.bit16_groups = optim.bit16_groups
+        self.fp32_flat_groups = optim.single_partition_of_fp32_groups
+        self.param2group = self.get_group_index()
+        self.group_offset = {}
+        self.get_group_offset()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        if getattr(self.optim, "cpu_offload", False):
+            grads = self.optim.single_partition_of_fp32_groups[group_idx].grad
+            start, numel = self.get_position(lp_param, group_idx)
+            grad = grads.narrow(0, start, numel)
+        else:
+            grad = self.optim.averaged_gradients[group_idx][param_id]
+        return grad
+
+    def get_group_offset(self):
+        for group_idx, group in enumerate(self.bit16_groups):
+            self.group_offset[group_idx] = -1
+            for lp_param in group:
+                if self.param_not_in_partition(lp_param, group_idx):
+                    self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
+                else:
+                    break
+
+
+class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, optim):
+        super().__init__(optim)
+        self.stage = '3'
+        self.bit16_groups = optim.fp16_groups
+        self.fp32_flat_groups = optim.fp32_partitioned_groups_flat
+        self.param2group = self.get_group_index()
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        """Each param partioned across all zero ranks"""
+        return False
+
+    def get_position(self, lp_param, group_idx):
+        param_id = self.optim.get_param_id(lp_param)
+        return self.optim.grad_position[param_id][1:]
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.optim.averaged_gradients[group_idx][param_id]
+
+
+class OptimizerMonFactory:
+    _optimizer_mon_map = {
+        "FP32Optimizer": OptimizerMon,
+        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
+        "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+        "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+        "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
+        "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
+        "Adam": OptimizerMon
+    }
+
+    @staticmethod
+    def create_optimizer_mon(optimizer):
+        # auto replace opt_ty
+        optimizer_class = optimizer.__class__.__name__
+        if optimizer_class == "ChainedOptimizer":
+            optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+        logger.info(f'The optimizer type is {optimizer_class}')
+
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+        return optimizer_mon_class(optimizer)
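For orientation, the sketch below shows how the new module is meant to be driven: OptimizerMonFactory picks a monitor subclass from the optimizer's class name, and fetch_grad/fetch_mv return per-parameter gradients and Adam moments keyed by the monitor's tags. The monitor object and the attributes it is assumed to carry (params2name, name2tag, duplicate_param, the heatmap visualizers) are inferred from how the code above references them; this is a minimal usage sketch, not the package's actual TrainerMon wiring.

# Hypothetical usage sketch (not part of the package).
from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory

def collect_optimizer_metrics(monitor, optimizer, params2name):
    # Pick the monitor subclass from the optimizer's class name,
    # e.g. "DeepSpeedZeroOptimizer" -> DeepSpeedZeroOptimizerStage1or2Mon.
    opt_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
    # Post-update gradients, keyed by the tags registered in monitor.name2tag.
    grad_dict = opt_mon.fetch_grad(monitor, params2name)
    # exp_avg / exp_avg_sq plus bias-corrected update and ratio statistics.
    exp_avg, exp_avg_sq, update, ratio = opt_mon.fetch_mv(monitor, params2name)
    return grad_dict, exp_avg, exp_avg_sq, update, ratio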
msprobe/mindspore/monitor/utils.py
CHANGED
@@ -24,18 +24,24 @@ from msprobe.core.common.log import logger
 from msprobe.core.common.file_utils import check_file_or_directory_path
 
 
-def get_single_metrics(op_list, tag, tensor, output=None):
+def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
     if output is None:
         output = {}
     if tag not in output:
         output[tag] = {}
     for op in op_list:
         func = FUNC_MAP.get(op)
-
+        if op == "zeros":
+            statistic = func(tensor, eps)
+        else:
+            statistic = func(tensor)
         if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
             statistic = float(statistic)
             statistic = Tensor(statistic)
-
+        if isinstance(statistic, Tensor):
+            output[tag][op] = statistic.astype(mstype.float32)
+        else:
+            output[tag][op] = statistic
 
 
 def get_metrics(op_list, tag2tensor, eps, output=None):
@@ -44,7 +50,7 @@ def get_metrics(op_list, tag2tensor, eps, output=None):
     for tag, tensor in tag2tensor.items():
         if tag not in output:
             output[tag] = {}
-        get_single_metrics(op_list, tag, tensor, output)
+        get_single_metrics(op_list, tag, tensor, eps, output)
     return output
 
 
@@ -91,6 +97,11 @@ def validate_ops(ops):
         default_op = MonitorConst.OP_LIST[0]
         valid_ops.append(default_op)
         logger.info(f"There is no valid ops, default op {default_op} is used")
+    # add the default shape and dtype ops
+    if "shape" not in valid_ops:
+        valid_ops.append("shape")
+    if "dtype" not in valid_ops:
+        valid_ops.append("dtype")
     return valid_ops
 
 
@@ -171,7 +182,7 @@ def validate_alert(alert):
            args = rule.get("args")
            if args and isinstance(args, dict):
                threshold = args.get("threshold")
-                if not isinstance(threshold, float) or threshold < 0:
+                if not isinstance(threshold, (float, int)) or threshold < 0:
                    raise TypeError('threshold must be float and not less than 0')
    dump = alert.get('dump')
    if dump and not isinstance(dump, bool):
@@ -217,6 +228,13 @@ def validate_dynamic_on(dynamic_on):
        raise TypeError('dynamic_on should be a bool')
 
 
+def validate_monitor_mbs_grad(monitor_mbs_grad):
+    if not isinstance(monitor_mbs_grad, bool):
+        logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
+        return False
+    return monitor_mbs_grad
+
+
 def validate_config(config):
     config['ops'] = validate_ops(config.get('ops', []))
 
@@ -266,6 +284,8 @@ def validate_config(config):
     collect_times = config.get('collect_times', int(1e8))
     validate_collect_times(collect_times)
 
+    config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
+
     dynamic_on = config.get('dynamic_on', False)
     validate_dynamic_on(dynamic_on)
 
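The eps argument threaded from get_metrics into get_single_metrics above only reaches the "zeros" statistic. The real metric function is supplied by FUNC_MAP; the sketch below only illustrates the assumed semantics, namely that elements whose magnitude falls below eps are counted as zeros.

# Illustrative sketch only; the actual "zeros" function lives in FUNC_MAP.
from mindspore import mint
from mindspore import dtype as mstype

def zeros_ratio(tensor, eps=1e-8):
    # Fraction of elements whose absolute value is below eps (assumed semantics).
    near_zero = (mint.abs(tensor) < eps).astype(mstype.float32)
    return mint.sum(near_zero) / tensor.numel()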
msprobe/mindspore/ms_config.py
CHANGED
@@ -29,6 +29,7 @@ class TensorConfig(BaseConfig):
         self.check_mode = None
         self.file_format = json_config.get("file_format")
         self.check_config()
+        self._check_summary_mode()
         self._check_config()
 
     def _check_config(self):
@@ -42,12 +43,23 @@ class StatisticsConfig(BaseConfig):
         self.file_format = None
         self.check_mode = None
         self.check_config()
-        self.
+        self._check_summary_mode()
 
-
-
+        self.tensor_list = json_config.get("tensor_list", [])
+        self._check_str_list_config(self.tensor_list, "tensor_list")
+        self.stat_cal_mode = json_config.get("device", "host")
+        self.device_stat_precision_mode = json_config.get("precision", "high")
+        self._check_stat_params()
+
+    def _check_stat_params(self):
+        if self.stat_cal_mode not in ["device", "host"]:
+            raise Exception("Config param [device] is invalid, expected from [\"device\", \"host\"]")
+        if self.device_stat_precision_mode not in ["high", "low"]:
+            raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]")
+
+    def _check_summary_mode(self):
         muti_opt = ["md5", "max", "min", "mean", "l2norm"]
-        if isinstance(self.summary_mode, str) and self.summary_mode not in
+        if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE:
             raise Exception("summary_mode is invalid")
         if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
             raise Exception("summary_mode is invalid")
@@ -132,14 +144,3 @@ def parse_task_config(task, json_config):
     if task not in TaskDict:
         raise Exception("task is invalid.")
     return TaskDict.get(task)(task_map)
-
-
-def parse_json_config(json_file_path):
-    if not json_file_path:
-        raise Exception("json file path is None")
-    json_config = load_json(json_file_path)
-    common_config = parse_common_config(json_config)
-    if not common_config.task:
-        common_config.task = Const.STATISTICS
-    task_config = parse_task_config(common_config.task, json_config)
-    return common_config, task_config
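For reference, a hypothetical statistics-task fragment covering the new keys parsed above; the defaults and allowed values come from the code, while the placement of these keys inside a full config.json follows the existing 02.config_introduction.md documentation and is not reproduced from this diff.

# Assumed statistics-task options after this change (illustrative values).
statistics_task_config = {
    "summary_mode": ["max", "min", "mean", "l2norm"],  # list form is checked against muti_opt
    "tensor_list": [],      # new: string list, validated by _check_str_list_config
    "device": "host",       # new: where statistics are computed, "host" (default) or "device"
    "precision": "high",    # new: device-statistic precision, "high" (default) or "low"
}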
msprobe/mindspore/task_handler_factory.py
CHANGED
@@ -29,11 +29,14 @@ class TaskHandlerFactory:
     }
 
     @staticmethod
-    def create(config: DebuggerConfig):
+    def create(config: DebuggerConfig, model=None):
         task = TaskHandlerFactory.tasks.get(config.task)
         if not task:
             raise Exception("Valid task is needed.")
-
+        if task == DumpToolFactory:
+            handler = task.create(config, model)
+        else:
+            handler = task.create(config)
         if not handler:
             raise Exception("Can not find task handler")
         return handler
msprobe/msprobe.py
CHANGED
@@ -22,6 +22,8 @@ from msprobe.core.common.log import logger
 from msprobe.core.compare.utils import _compare_parser
 from msprobe.core.compare.compare_cli import compare_cli
 from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
+from msprobe.core.config_check.config_check_cli import _config_checking_parser, \
+    _run_config_checking_command
 
 
 def is_module_available(module_name):
@@ -51,6 +53,9 @@ def main():
     graph_service_cmd_parser = subparsers.add_parser('graph')
     op_generate_cmd_parser = subparsers.add_parser('op_generate')
     merge_result_parser = subparsers.add_parser('merge_result')
+    config_checking_parser = subparsers.add_parser('config_check')
+    nan_analyze_parser = subparsers.add_parser('nan_analyze')
+    _config_checking_parser(config_checking_parser)
     _compare_parser(compare_cmd_parser)
     _merge_result_parser(merge_result_parser)
 
@@ -71,6 +76,7 @@ def main():
         from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
         from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
             _run_operator_generate_commond
+        from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze
 
         _run_ut_parser(run_ut_cmd_parser)
         _run_ut_parser(multi_run_ut_cmd_parser)
@@ -80,6 +86,7 @@ def main():
         _run_overflow_check_parser(run_overflow_check_cmd_parser)
         _pt_graph_service_parser(graph_service_cmd_parser)
         _op_generator_parser(op_generate_cmd_parser)
+        _nan_analyze_parser(nan_analyze_parser)
     elif framework_args.framework == Const.MS_FRAMEWORK:
         from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
         from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
@@ -91,6 +98,10 @@ def main():
 
         _ms_graph_service_parser(graph_service_cmd_parser)
 
+        from msprobe.mindspore.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
+            _run_operator_generate_commond
+        _op_generator_parser(op_generate_cmd_parser)
+
     args = parser.parse_args(sys.argv[1:])
     if sys.argv[2] == Const.PT_FRAMEWORK:
         if not is_torch_available:
@@ -118,6 +129,10 @@ def main():
             compare_cli(args)
         elif sys.argv[3] == "merge_result":
             merge_result_cli(args)
+        elif sys.argv[3] == "config_check":
+            _run_config_checking_command(args)
+        elif sys.argv[3] == "nan_analyze":
+            _run_nan_analyze(args)
     else:
         if not is_module_available(Const.MS_FRAMEWORK):
             logger.error("MindSpore does not exist, please install MindSpore library")
@@ -134,9 +149,13 @@ def main():
             mul_api_checker_main(args)
         elif sys.argv[3] == "graph":
             _ms_graph_service_command(args)
+        elif sys.argv[3] == 'op_generate':
+            _run_operator_generate_commond(args)
         elif sys.argv[3] == "code_mapping":
             from msprobe.mindspore.code_mapping.main import code_mapping_main
             code_mapping_main(args)
+        elif sys.argv[3] == "config_check":
+            _run_config_checking_command(args)
 
 
 if __name__ == "__main__":
msprobe/nan_analyze/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.