mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/config_check/ckpt_compare/megatron_loader.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from collections import defaultdict
+from typing import Dict
+import numpy as np
+from msprobe.core.common.log import logger
+from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import FileOpen, load_yaml
+from msprobe.core.common.framework_adapter import FmkAdp
+
+# both weights and bias are partitioned in column parallel
+COLUMN_PARALLEL_PARAMS = ['linear_qkv', 'linear_fc1', 'word_embeddings.weight', 'output_layer.weight']
+# only weights are partitioned in row parallel
+ROW_PARALLEL_PARAMS = ['linear_fc2.weight', 'linear_proj.weight']
+ARGS = 'args'
+LAYER_IDX_PATTERN = re.compile(r'layers\.(\d+)\.')
+EXPERT_IDX_PATTERN = re.compile(r'experts\.(\d+)\.')
+ITER_DIR_PATTERN = re.compile(r'iter_([\d]{7})')
+
+
+@recursion_depth_decorator('')
+def _get_parameter(weights, prefix=''):
+    for k, v in weights.items():
+        name = Const.SEP.join([prefix, k]).strip(Const.SEP)
+        if isinstance(v, dict):
+            yield from _get_parameter(v, prefix=name)
+        elif FmkAdp.is_tensor(v):
+            yield name, FmkAdp.asnumpy(v)
+
+
+def _map_to_mcore_local_names(param_name: str) -> str:
+    """Map parameter names to mcore + local transformer implementation names."""
+    mcore_local_map = load_yaml(os.path.join(os.path.dirname(__file__), 'name_mapping.yaml'))
+    for other_name, mcore_local_name in mcore_local_map.items():
+        param_name = param_name.replace(other_name, mcore_local_name)
+
+    return param_name
+
+
+def _parse_real_layer_idx(param_name, num_layers_per_stage, pp_size, pp_rank):
+    """Map local (virtual) pipeline stage layer index to global layer index.
+
+    For virtual pipeline parallel, each pipeline stage is further divided into virtual stages.
+    The global layer index needs to account for both the pipeline stage and the virtual stage.
+
+    Args:
+        param_name (str): Parameter name containing layer index: layers.x.<submodule_name>/<vpp_stage>
+        num_layers_per_stage (int): Number of layers per pipeline stage
+        pp_size (int): Pipeline parallel size
+        pp_rank (int): Pipeline parallel rank
+
+    Returns:
+        str: Parameter name with the local layer index replaced by the global layer index
+    """
+    # Extract local layer index from parameter name
+    layer_match = re.search(LAYER_IDX_PATTERN, param_name)
+    param_name, vpp_stage = param_name.split(Const.SCOPE_SEPARATOR)
+    if not layer_match:
+        return param_name
+
+    local_layer_idx = int(layer_match.group(1))
+    vpp_stage = int(vpp_stage)
+
+    # Calculate global layer index based on pipeline stage and virtual stage
+    real_layer_idx = local_layer_idx + (pp_size * vpp_stage + pp_rank) * num_layers_per_stage
+
+    return param_name.replace(f'layers.{local_layer_idx}', f'layers.{real_layer_idx}')
+
+
+def _parse_real_expert_idx(param_name, num_experts_per_rank, exp_rank):
+    """Map local expert index to global expert index. TODO: shared expert
+
+    For expert parallel, experts are distributed across ranks. This function maps
+    the local expert index on a rank to its global index across all ranks.
+
+    Args:
+        param_name (str): Parameter name containing local expert index
+        num_experts_per_rank (int): Number of experts on each rank
+        exp_rank (int): Expert parallel rank
+
+    Returns:
+        str: Parameter name with local expert index replaced by global expert index
+    """
+    # Extract local expert index from parameter name
+    expert_match = re.search(EXPERT_IDX_PATTERN, param_name)
+    if not expert_match:
+        return param_name
+
+    local_expert_idx = int(expert_match.group(1))
+    # Calculate global expert index based on expert parallel rank
+    real_experts_idx = local_expert_idx + exp_rank * num_experts_per_rank
+
+    return param_name.replace(f'experts.{local_expert_idx}', f'experts.{real_experts_idx}')
+
+
+def _consolidate_tp_weights(weights: Dict) -> Dict:
+    """Consolidate weights from different tensor parallel ranks into combined tensors.
+
+    Args:
+        weights: Dictionary of weights with rank information in keys
+
+    Returns:
+        Dict: Consolidated weights without rank information
+    """
+    consolidated = {}
+    for key, tensors in weights.items():
+        if any(name in key for name in COLUMN_PARALLEL_PARAMS):
+            # Column parallel - concatenate along the output dimension (dim 0)
+            combined = np.concatenate(tensors, axis=0)
+        elif any(name in key for name in ROW_PARALLEL_PARAMS):
+            # Row parallel - concatenate along the input dimension (dim 1)
+            combined = np.concatenate(tensors, axis=1)
+        else:
+            # For other params, verify identical and use first
+            if not all(np.allclose(tensors[0], t) for t in tensors[1:]):
+                logger.warning(f"Inconsistent values for {key} across TP ranks")
+            combined = tensors[0]
+
+        consolidated[key] = combined
+    return consolidated
+
+
+def _parse_num_layers_per_stage(tp_partition):
+    match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()]
+    layer_idx = [int(i[0]) for i in match if i]
+    num_layers_per_pipeline_stage = max(layer_idx) + 1
+
+    return num_layers_per_pipeline_stage
+
+
+def parse_parallel_size(checkpoint_dir: str):
+    """Parse tensor, pipeline and expert parallel sizes from checkpoint filenames.
+
+    Args:
+        checkpoint_dir (str): Directory containing checkpoint files
+
+    Returns:
+        tuple: (tp_size, pp_size, exp_size, num_experts) read from the checkpoint args
+    """
+    # Find all rank directories
+    rank_dirs = [d for d in os.listdir(checkpoint_dir) if d.startswith('mp_rank_')]
+
+    if not rank_dirs:
+        raise ValueError(f"No checkpoint rank directories found in {checkpoint_dir}")
+
+    ckpt = FmkAdp.load_checkpoint(
+        os.path.join(checkpoint_dir, rank_dirs[0], 'model_optim_rng.pt'),
+        to_cpu=True,
+        weights_only=False)
+    args = ckpt[ARGS]
+    return (
+        args.tensor_model_parallel_size,
+        args.pipeline_model_parallel_size,
+        args.expert_model_parallel_size,
+        args.num_experts
+    )
+
+
+def parse_iteration(checkpoint_path: str) -> str:
+    """
+    Parse the checkpoint iteration directory from a given checkpoint path.
+
+    If the path is a top-level checkpoint directory, this function reads the
+    'latest_checkpointed_iteration.txt' file to determine the latest iteration.
+    If the path is already an iteration directory (e.g., 'iter_0000005'), it extracts
+    the iteration number from the path.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint directory or iteration directory.
+
+    Returns:
+        str: The full path to the checkpoint directory for the determined iteration.
+
+    Raises:
+        ValueError: If the checkpoint directory for the determined iteration does not exist.
+    """
+    iteration = None
+    tracker_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
+    if os.path.exists(tracker_file):
+        with FileOpen(tracker_file, 'r') as f:
+            latest_iteration = f.read().strip()
+        if latest_iteration != 'release':
+            try:
+                iteration = int(latest_iteration)
+            except Exception:
+                logger.warning(
+                    f"The latest_checkpointed_iteration is supposed to be `release` or an int. \
+                    But {latest_iteration} is found."
+                )
+            checkpoint_path = os.path.join(checkpoint_path, f'iter_{iteration:07d}')
+    else:
+        match = re.findall(ITER_DIR_PATTERN, checkpoint_path)
+        if match:
+            iteration = int(match[0])
+
+    # Checkpoint directory for this iteration
+    logger.info(f"Loaded checkpoint from iteration {iteration}")
+
+    if not os.path.exists(checkpoint_path):
+        raise ValueError(f"Checkpoint directory not found: {checkpoint_path}")
+
+    return checkpoint_path
+
+
+def get_weights_from_state_dict(state_dict):
+    weights = {}
+    vpp_stage = 0
+    if 'model' in state_dict:
+        model_weights = state_dict['model']
+
+        for key, value in _get_parameter(model_weights):
+            key = _map_to_mcore_local_names(key)
+            weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value
+
+    elif 'model0' in state_dict:
+        # vpp enabled
+        while f'model{vpp_stage}' in state_dict:
+            model_weights = state_dict[f'model{vpp_stage}']
+            for key, value in _get_parameter(model_weights):
+                key = _map_to_mcore_local_names(key)
+                weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value
+            vpp_stage += 1
+    return weights
+
+
+def load_megatron_weights(checkpoint_path: str) -> Dict:
+    """Load Megatron parallel checkpoint weights into a single dictionary.
+
+    Args:
+        checkpoint_path (str): Base checkpoint directory path
+
+    Returns:
+        combined_weights: Dict with weights from all ranks, keys include rank info
+    """
+    try:
+        import megatron
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError("No module named 'megatron', which is required to load a megatron ckpt") from e
+
+    # Find latest iteration if not specified
+    checkpoint_path = parse_iteration(checkpoint_path)
+
+    # Parse parallel sizes from checkpoint directory structure
+    tp_size, pp_size, exp_size, num_experts = parse_parallel_size(checkpoint_path)
+    combined_weights = {}
+
+    # Load checkpoints from all ranks
+    for exp_rank in range(exp_size):
+        num_layers_per_pipeline_stage = 0
+        for pp_rank in range(pp_size):
+            tp_partition = defaultdict(list)
+            for tp_rank in range(tp_size):
+                # Construct checkpoint path based on parallel ranks
+                if pp_size > 1:
+                    rank_dir = f'mp_rank_{tp_rank:02d}_{pp_rank:03d}'
+                else:
+                    rank_dir = f'mp_rank_{tp_rank:02d}'
+
+                if exp_size > 1:
+                    rank_dir = f'{rank_dir}_{exp_rank:03d}'
+
+                ckpt_file = os.path.join(checkpoint_path, rank_dir, 'model_optim_rng.pt')
+                try:
+                    state_dict = FmkAdp.load_checkpoint(ckpt_file, to_cpu=True, weights_only=False)
+                    partition = get_weights_from_state_dict(state_dict)
+                    for key, weight in partition.items():
+                        tp_partition[key].append(weight)
+
+                except Exception as load_error:
+                    logger.warning(f"Error loading {ckpt_file}: {load_error}")
+
+            if not tp_partition:
+                raise ValueError('No state loaded.')
+
+            if not num_layers_per_pipeline_stage:
+                num_layers_per_pipeline_stage = _parse_num_layers_per_stage(tp_partition)
+
+            consolidated_weight = _consolidate_tp_weights(tp_partition)
+            for key, value in consolidated_weight.items():
+                key = _parse_real_layer_idx(key, num_layers_per_pipeline_stage, pp_size, pp_rank)
+                if num_experts:
+                    key = _parse_real_expert_idx(key, num_experts // exp_size, exp_rank)
+                combined_weights[key] = value
+
+    logger.info(f"Found {len(combined_weights)} total parameters across all ranks")
+
+    return combined_weights
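To make the docstrings above concrete, here is a minimal, self-contained sketch of the re-indexing arithmetic used by `_parse_real_layer_idx` and `_parse_real_expert_idx`. The parallel sizes and ranks are hypothetical; only the formulas come from the code above.

```python
# Hypothetical setup: 2 pipeline stages, 4 layers per (virtual) stage.
pp_size, num_layers_per_stage = 2, 4

# 'layers.1.<...>' held by virtual stage 1 on pipeline rank 0:
local_layer_idx, vpp_stage, pp_rank = 1, 1, 0
real_layer_idx = local_layer_idx + (pp_size * vpp_stage + pp_rank) * num_layers_per_stage
assert real_layer_idx == 9  # local layer 1 of vpp stage 1 is global layer 9

# 'experts.0.<...>' held by expert-parallel rank 1, with 4 experts per rank:
local_expert_idx, num_experts_per_rank, exp_rank = 0, 4, 1
real_expert_idx = local_expert_idx + exp_rank * num_experts_per_rank
assert real_expert_idx == 4  # local expert 0 on exp rank 1 is global expert 4
```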
msprobe/core/config_check/ckpt_compare/metrics.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from msprobe.core.common.log import logger
+from msprobe.core.compare.npy_compare import CompareOps
+
+
+
+def in_different_shape(a, b):
+    if a.shape != b.shape:
+        logger.warning(f"a, b are in different shape. a: {a.shape}, b: {b.shape}")
+        return True
+    return False
+
+
+def l2_distance(a, b):
+    if a is None or b is None:
+        return None
+    if in_different_shape(a, b):
+        return None
+    return np.linalg.norm(a - b).item()
+
+
+def cos_sim(a, b):
+    if a is None or b is None:
+        return None
+
+    if in_different_shape(a, b):
+        return None
+    if a.ndim > 0:
+        a = a.flatten().squeeze()
+        b = b.flatten().squeeze()
+
+    num = a.dot(b)
+    a_norm = np.linalg.norm(a)
+    b_norm = np.linalg.norm(b)
+
+    if a_norm == 0 and b_norm == 0:
+        return 1.
+    if a_norm == 0 or b_norm == 0:
+        logger.warning('One tensor norm is zero.')
+        return None
+
+    sim = num / (a_norm * b_norm)
+
+    return sim.item()
+
+
+def numel(a, b):
+    n1 = a.size
+    n2 = b.size
+    if n1 != n2:
+        logger.warning('parameters have different numbers of elements')
+        return (n1, n2)
+    return n1
+
+
+def shape(a, b):
+    if in_different_shape(a, b):
+        return [list(a.shape), list(b.shape)]
+    return list(a.shape)
+
+
+METRIC_FUNC = {
+    'l2': l2_distance,
+    'cos': cos_sim,
+    'numel': numel,
+    'shape': shape
+}
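A short usage sketch for the metric table above, assuming the module is importable as `msprobe.core.config_check.ckpt_compare.metrics` (the path given in the file list); the input arrays are made up.

```python
import numpy as np

from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC

# Two hypothetical parameter tensors from the bench and compare checkpoints.
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[1.0, 2.0], [3.0, 5.0]])

for metric_name, metric in METRIC_FUNC.items():
    print(metric_name, metric(a, b))
# l2 1.0, cos ~0.994, numel 4, shape [2, 2]
```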
msprobe/core/config_check/ckpt_compare/name_mapping.yaml
@@ -0,0 +1,12 @@
+self_attention.linear_qkv.layer_norm_: input_layernorm.
+language_model.: ''
+encoder: decoder
+.input_norm.: .input_layernorm.
+query_key_value: linear_qkv
+.dense.: .linear_proj.
+post_attention_norm: pre_mlp_layernorm
+dense_h_to_4h: linear_fc1
+dense_4h_to_h: linear_fc2
+mlp.local_experts: mlp.experts.local_experts
+final_norm: final_layernorm
+word_embeddings_for_head: output_layer
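`_map_to_mcore_local_names` in megatron_loader.py (above) applies this table as ordered substring replacements, key to value. A minimal illustration with a hypothetical parameter name and a subset of the entries:

```python
# Subset of name_mapping.yaml, applied the way the loader applies it.
name_mapping = {
    'query_key_value': 'linear_qkv',
    '.input_norm.': '.input_layernorm.',
}

param_name = 'layers.0.self_attention.query_key_value.weight'
for other_name, mcore_local_name in name_mapping.items():
    param_name = param_name.replace(other_name, mcore_local_name)
print(param_name)  # layers.0.self_attention.linear_qkv.weight
```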
msprobe/core/config_check/config_check_cli.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.config_check.config_checker import ConfigChecker
+from msprobe.core.config_check.ckpt_compare.ckpt_comparator import compare_checkpoints
+from msprobe.core.common.log import logger
+
+
+def pack(shell_path, output_path, framework):
+    ConfigChecker(shell_path=shell_path, output_zip_path=output_path, fmk=framework)
+
+
+def compare(bench_zip_path, cmp_zip_path, output_path, framework):
+    ConfigChecker.compare(bench_zip_path, cmp_zip_path, output_path, framework)
+
+
+def _config_checking_parser(parser):
+    parser.add_argument('-d', '--dump', nargs='*', help='Collect the train config into a zip file')
+    parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or checkpoints')
+    parser.add_argument('-o', '--output', help='output path, default is current directory')
+
+
+def _run_config_checking_command(args):
+    if args.dump is not None:
+        output_dirpath = args.output if args.output else "./config_check_pack.zip"
+        pack(args.dump, output_dirpath, args.framework)
+    elif args.compare:
+        if args.compare[0].endswith('zip'):
+            logger.info('The input paths are zip files; comparing packed configs.')
+            output_dirpath = args.output if args.output else "./config_check_result"
+            compare(args.compare[0], args.compare[1], output_dirpath, args.framework)
+        else:
+            logger.info('Comparing model checkpoints.')
+            output_dirpath = args.output if args.output else "./ckpt_similarity.json"
+            compare_checkpoints(args.compare[0], args.compare[1], output_dirpath)
+
+    else:
+        logger.error("The param is not correct, you need to give '-d' for dump or '-c' for compare.")
+        raise Exception("The param is not correct, you need to give '-d' for dump or '-c' for compare.")
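A sketch of driving the same two flows from Python rather than the CLI; the functions and module paths are taken from the hunk and file list above, but the input paths are hypothetical.

```python
from msprobe.core.config_check.config_check_cli import pack, compare
from msprobe.core.config_check.ckpt_compare.ckpt_comparator import compare_checkpoints

# Dump mode: collect the training config described by a launch script into a zip.
pack(shell_path=['./train.sh'], output_path='./config_check_pack.zip', framework='pytorch')

# Compare two packed zips; writes result.xlsx under the output directory.
compare('./bench_pack.zip', './cmp_pack.zip', './config_check_result', 'pytorch')

# Non-zip inputs fall through to checkpoint comparison.
compare_checkpoints('./bench_ckpt', './cmp_ckpt', './ckpt_similarity.json')
```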
msprobe/core/config_check/config_checker.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+
+import pandas as pd
+
+from msprobe.core.common.file_utils import save_excel, split_zip_file_path, \
+    create_directory, extract_zip
+from msprobe.core.common.framework_adapter import FmkAdp
+from msprobe.core.config_check.checkers.base_checker import PackInput
+from msprobe.core.config_check.utils.utils import config_checking_print
+from msprobe.core.common.const import Const
+
+
+class ConfigChecker:
+    checkers = {}
+    pre_forward_fun_list = []
+    result_filename = "result.xlsx"
+    result_header = ["filename", "pass_check"]
+    step = 0
+
+    def __init__(self, model=None, shell_path=None, output_zip_path="./config_check_pack.zip", fmk="pytorch"):
+        FmkAdp.set_fmk(fmk)
+        self.pack_input = PackInput(output_zip_path, model, shell_path)
+        file_path, file_name = split_zip_file_path(self.pack_input.output_zip_path)
+        if not os.path.exists(file_path):
+            create_directory(file_path)
+        self.pack()
+
+    @staticmethod
+    def compare(bench_zip_path, cmp_zip_path, output_path, fmk=Const.PT_FRAMEWORK):
+        if os.path.exists(output_path):
+            shutil.rmtree(output_path)
+        bench_dir = os.path.join(output_path, "bench")
+        cmp_dir = os.path.join(output_path, "cmp")
+        extract_zip(bench_zip_path, bench_dir)
+        config_checking_print(f"extract zip file {bench_zip_path} to {bench_dir}")
+        extract_zip(cmp_zip_path, cmp_dir)
+        config_checking_print(f"extract zip file {cmp_zip_path} to {cmp_dir}")
+
+        result = []
+        summary_result = []
+        for checker in ConfigChecker.checkers.values():
+            checker_name, pass_check, df = checker.compare_ex(bench_dir, cmp_dir, output_path, fmk)
+            if checker_name:
+                summary_result.append([checker_name, pass_check])
+            if df is not None:
+                result.append((df, checker_name))
+        summary_result_df = pd.DataFrame(summary_result, columns=ConfigChecker.result_header)
+        result.insert(0, (summary_result_df, "summary"))
+        save_excel(os.path.join(output_path, ConfigChecker.result_filename), result)
+        config_checking_print(f"config checking result saved to {os.path.realpath(output_path)}")
+
+    @staticmethod
+    def apply_patches(fmk=Const.PT_FRAMEWORK):
+        for checker in ConfigChecker.checkers.values():
+            checker.apply_patches(fmk)
+
+    def pack(self):
+        config_checking_print(f"pack result zip path {os.path.realpath(self.pack_input.output_zip_path)}")
+
+        def hook(model, args, kwargs):
+            for collect_func in self.pre_forward_fun_list:
+                collect_func(model, args, kwargs, ConfigChecker.step)
+            ConfigChecker.step += 1
+
+        if self.pack_input.model:
+            FmkAdp.register_forward_pre_hook(self.pack_input.model, hook, with_kwargs=True)
+        for checker in ConfigChecker.checkers.values():
+            if checker.input_needed and not getattr(self.pack_input, checker.input_needed):
+                continue
+            if FmkAdp.is_initialized() and FmkAdp.get_rank() != 0 and not checker.multi_rank:
+                continue
+            checker.pack(self.pack_input)
+
+
+def register_checker_item(key, cls=None):
+    if cls is None:
+        # Called without cls: return a decorator that registers the class under `key`.
+        return lambda cls: register_checker_item(key, cls)
+    ConfigChecker.checkers[key] = cls
+    return cls
+
+
+def register_pre_forward_fun_list(func):
+    ConfigChecker.pre_forward_fun_list.append(func)
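A minimal sketch of how a checker enters this registry through `register_checker_item`. Real checkers subclass `BaseChecker` from checkers/base_checker.py; the class below is hypothetical and only mirrors the attributes the registry reads (`input_needed`, `multi_rank`, `pack`, `compare_ex`, `apply_patches`).

```python
from msprobe.core.config_check.config_checker import ConfigChecker, register_checker_item


@register_checker_item("demo")
class DemoChecker:
    input_needed = None   # no PackInput field is required for this checker
    multi_rank = False    # pack on rank 0 only

    @staticmethod
    def pack(pack_input):
        print("collecting demo config into", pack_input.output_zip_path)

    @staticmethod
    def compare_ex(bench_dir, cmp_dir, output_path, fmk):
        return "demo", True, None  # (checker_name, pass_check, optional DataFrame)

    @staticmethod
    def apply_patches(fmk):
        pass


assert ConfigChecker.checkers["demo"] is DemoChecker
```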
msprobe/pytorch/parse.py → msprobe/core/config_check/resource/dependency.yaml
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
-
-
+dependency:
+- transformers
+- deepspeed
+- megatron
+- numpy
+- datasets
+- peft
msprobe/core/config_check/resource/env.yaml
@@ -0,0 +1,57 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+HCCL_DETERMINISTIC:
+  npu:
+    name: HCCL_DETERMINISTIC
+    default_value: False
+  gpu:
+    name: NCCL_DETERMINISTIC
+    default_value: False
+
+HCCL_ALGO:
+  npu:
+    name: HCCL_ALGO
+    default_value: None
+  gpu:
+    name: NCCL_ALGO
+    default_value: None
+
+HCCL_INTRA_ROCE_ENABLE:
+  npu:
+    name: HCCL_INTRA_ROCE_ENABLE
+    default_value: 0
+
+
+HCCL_INTRA_PICE_ENABLE:
+  npu:
+    name: HCCL_INTRA_ROCE_ENABLE
+    default_value: 1
+
+ASCEND_LAUNCH_BLOCKING:
+  npu:
+    name: ASCEND_LAUNCH_BLOCKING
+    default_value: 0
+  gpu:
+    name: CUDA_LAUNCH_BLOCKING
+    default_value: 0
+
+ASCEND_RT_VISIBLE_DEVICES:
+  npu:
+    name: ASCEND_RT_VISIBLE_DEVICES
+    default_value: None
+  gpu:
+    name: CUDA_VISIBLE_DEVICES
+    default_value: None
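The actual comparison logic lives in checkers/env_args_checker.py; purely as an illustration of what this table encodes, a hedged sketch that resolves the platform-specific variable name and its effective value:

```python
import os

from msprobe.core.common.file_utils import load_yaml

env_table = load_yaml('env.yaml')  # hypothetical local copy of the table above


def resolve(entry, platform):
    """Return (variable name, effective value) for platform 'npu' or 'gpu'."""
    spec = entry.get(platform)
    if spec is None:  # e.g. HCCL_INTRA_ROCE_ENABLE has no GPU counterpart
        return None, None
    return spec['name'], os.environ.get(spec['name'], str(spec['default_value']))


print(resolve(env_table['HCCL_DETERMINISTIC'], 'npu'))
# ('HCCL_DETERMINISTIC', 'False') unless the variable is set in the environment
```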
msprobe/core/config_check/resource/hyperparameter.yaml
@@ -0,0 +1,21 @@
+learning_rate:
+- lr
+- learningrate
+
+batch_size:
+- batch
+- bs
+- batch_size_per_gpu
+
+epochs:
+- num_epochs
+- max_epochs
+- epoch
+
+weight_decay:
+- wd
+- weightdecay
+
+dropout_rate:
+- dropout
+- drop_rate
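These alias lists imply a canonicalization step before hyperparameters from two runs are matched; the real matching lives in checkers/hyperparameter_checker.py and utils/hyperparameter_parser.py, so treat this as an illustrative sketch only.

```python
# Alias -> canonical-name lookup built from (a subset of) the table above.
aliases = {
    'learning_rate': ['lr', 'learningrate'],
    'batch_size': ['batch', 'bs', 'batch_size_per_gpu'],
}
canonical = {alias: name for name, alts in aliases.items() for alias in alts}


def canonicalize(key: str) -> str:
    """Normalize a hyperparameter name to its canonical form."""
    return canonical.get(key.lower().replace('-', '_'), key)


assert canonicalize('lr') == 'learning_rate'
assert canonicalize('batch_size_per_gpu') == 'batch_size'
```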