mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from functools import wraps
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MegatronStepInfo:
    """Process-wide record of Megatron micro-step progress, kept as class attributes."""

    # Set to True once a step function wrapped by wrap_megatron_step has been called.
    is_megatron = False
    # Direction of the most recently entered micro-step.
    is_forward = False
    is_backward = False
    # Per-direction micro-step counters; -1 means no step of that kind has run yet.
    forward_micro_step = -1
    backward_micro_step = -1

    @classmethod
    def reset(cls):
        """Restore every tracked class attribute to its initial value."""
        cls.is_megatron = cls.is_forward = cls.is_backward = False
        cls.forward_micro_step = cls.backward_micro_step = -1
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def wrap_megatron_step(func, is_forward=True):
    """Decorate a Megatron micro-step function so that each call updates MegatronStepInfo.

    On every call the wrapper marks the run as Megatron, records the step
    direction selected by ``is_forward``, and advances that direction's
    micro-step counter. It then delegates to ``func`` and returns its result.
    """
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        step_info = MegatronStepInfo
        if not step_info.is_megatron:
            step_info.is_megatron = True
        # Exactly one of the two direction flags is True after this call.
        step_info.is_forward = bool(is_forward)
        step_info.is_backward = not is_forward
        if is_forward:
            step_info.forward_micro_step += 1
        else:
            step_info.backward_micro_step += 1
        return func(*args, **kwargs)

    return wrapped_func
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_micro_step():
    """Return the micro-step counter that matches the direction of the current step."""
    if MegatronStepInfo.is_forward:
        return MegatronStepInfo.forward_micro_step
    return MegatronStepInfo.backward_micro_step
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def is_megatron():
    """Report whether MegatronStepInfo currently marks this run as a Megatron run."""
    step_info = MegatronStepInfo
    return step_info.is_megatron
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import List
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RankGroupGenerator(object):
    """Compute process-rank groups for the parallel dimensions of a Megatron-style layout.

    A global rank is treated as a mixed-radix number whose digits follow ``order``.
    The first token of ``order`` varies fastest. ``total_size`` is tp * dp * pp * cp.
    Expert parallelism (ep) is not an extra factor: it subdivides the data-parallel
    dimension when groups are generated with ``separate_ep=True``.
    """

    def __init__(self, tensor_parallel: int, expert_parallel: int, data_parallel: int,
                 pipeline_parallel: int, context_parallel: int, order: str) -> None:
        """Validate the sizes and the order string, then precompute both dimension layouts.

        Args:
            tensor_parallel: Tensor-parallel size (>= 1).
            expert_parallel: Expert-parallel size (>= 1).
            data_parallel: Data-parallel size (>= 1).
            pipeline_parallel: Pipeline-parallel size (>= 1).
            context_parallel: Context-parallel size (>= 1).
            order: Dash-separated, case-insensitive dimension tokens drawn from
                tp/pp/dp/ep/cp, e.g. "tp-cp-ep-dp-pp". Dimensions of size 1 may be
                omitted and are appended at the end. If ep is listed, it must be
                adjacent to dp.

        Raises:
            MsprobeException: INVALID_PARAM_ERROR when a size is not positive, the order
                contains unknown or duplicated tokens, ep is not adjacent to dp, or a
                dimension larger than 1 is missing from the order.
        """
        self.tensor_parallel = tensor_parallel
        self.expert_parallel = expert_parallel
        self.data_parallel = data_parallel
        self.pipeline_parallel = pipeline_parallel
        self.context_parallel = context_parallel
        self.total_size = tensor_parallel * data_parallel * pipeline_parallel * context_parallel

        self.parallel_sizes = {
            "tp": self.tensor_parallel,
            "pp": self.pipeline_parallel,
            "dp": self.data_parallel,
            "ep": self.expert_parallel,
            "cp": self.context_parallel,
        }
        self.original_order = order

        # Reject non-positive sizes here. A zero would otherwise surface later as a
        # ZeroDivisionError, and a negative size would silently yield empty groups.
        for name, size in self.parallel_sizes.items():
            if size < 1:
                logger.error(f"The parallel size ({name}) is ({size}), but it must be a positive integer.")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        normalized_order = order.lower()
        # Work on a token list rather than the raw string. An empty order then yields no
        # tokens (instead of a stray '' token), and names are matched exactly rather than
        # as substrings.
        tokens = [token.strip() for token in normalized_order.split('-')] if normalized_order.strip() else []

        unknown_tokens = [token for token in tokens if token not in self.parallel_sizes]
        if unknown_tokens:
            logger.error(f"Unknown parallel dimension(s) {unknown_tokens} in order ({self.original_order}), "
                         f"supported dimensions are {list(self.parallel_sizes)}.")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        if len(set(tokens)) != len(tokens):
            logger.error(f"Duplicated parallel dimension in order ({self.original_order}).")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        # ep subdivides dp, so the two must be neighbours in the order.
        if 'ep' in tokens:
            if 'dp' not in tokens or abs(tokens.index('ep') - tokens.index('dp')) != 1:
                logger.error(f"The ep and dp must be adjacent in order ({self.original_order}).")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        # Every dimension larger than 1 must be listed explicitly. Missing size-1
        # dimensions are appended at the end, where a factor of 1 changes no stride.
        for name, size in self.parallel_sizes.items():
            if name in tokens:
                continue
            if size != 1:
                logger.error(f"The parallel size ({name}) is ({size}), "
                             f"but it's not specified in order ({self.original_order}).")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            tokens.append(name)

        self.order_with_ep = '-'.join(tokens)
        self.order_without_ep = '-'.join(token for token in tokens if token != 'ep')

        # Per-dimension sizes, following the order. With ep separated, dp contributes only
        # its dp // ep share next to ep. Without ep, ep is dropped and dp keeps its full size.
        self.size_list_with_ep = []
        self.size_list_without_ep = []
        for item in tokens:
            if item == 'dp':
                self.size_list_with_ep.append(self.data_parallel // self.expert_parallel)
                self.size_list_without_ep.append(self.data_parallel)
            elif item == 'ep':
                self.size_list_with_ep.append(self.expert_parallel)
            else:
                self.size_list_with_ep.append(self.parallel_sizes[item])
                self.size_list_without_ep.append(self.parallel_sizes[item])

    @staticmethod
    def create_mask(order_str: str, target_tokens: str) -> List[bool]:
        """Return a mask over the tokens of ``order_str`` marking those listed in ``target_tokens``.

        Both arguments are dash-separated token strings. A target token that is absent
        from ``order_str`` raises ValueError (from ``list.index``).
        """
        order_elements = order_str.split('-')
        target_elements = target_tokens.split('-')
        mask = [False] * len(order_elements)
        for token in target_elements:
            mask[order_elements.index(token)] = True
        return mask

    @staticmethod
    def create_masked_rank_groups(
            total_size: int,
            parallel_dims: List[int],
            mask: List[bool],
    ) -> List[List[int]]:
        """Enumerate the rank groups spanned by the masked dimensions.

        Each group fixes one coordinate combination of the unmasked dimensions. It then
        lists every rank obtained by varying the masked dimensions. Strides are the prefix
        products of ``parallel_dims``, so the first dimension varies fastest.

        Args:
            total_size: Total number of ranks.
            parallel_dims: Size of each dimension, in order.
            mask: True for the dimensions that make up a group.

        Returns:
            A list of groups, each a list of global ranks.

        Raises:
            MsprobeException: INVALID_PARAM_ERROR if an index cannot be decomposed
                consistently with the given shape.
        """
        def compute_prefix_products(dimensions: List[int], initial: int = 1) -> List[int]:
            products = [initial]
            current = initial
            for dim in dimensions:
                current *= dim
                products.append(current)
            return products

        def calculate_inner_product(a: List[int], b: List[int]) -> int:
            return sum(x * y for x, y in zip(a, b))

        def decompose_index(index: int, shape: List[int], strides: List[int] = None) -> List[int]:
            if strides is None:
                strides = compute_prefix_products(shape)
            indices = [(index // stride) % dim for dim, stride in zip(shape, strides)]

            # Verify that the decomposition reconstructs the original index.
            if calculate_inner_product(indices, strides[:-1]) != index:
                error_msg = f"The index {index} with shape {shape} doesn't match decomposed indices {indices}."
                logger.error(error_msg)
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

            return indices

        # Split the dimensions into masked (inside a group) and unmasked (between groups).
        masked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if is_masked]
        unmasked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if not is_masked]

        # Global strides, then the strides belonging to the masked and unmasked dimensions.
        global_strides = compute_prefix_products(parallel_dims)
        masked_strides = [stride for stride, is_masked in zip(global_strides, mask) if is_masked]
        unmasked_strides = [stride for stride, is_masked in zip(global_strides, mask) if not is_masked]

        # Group size and number of groups.
        group_dim = compute_prefix_products(masked_dims)[-1]
        group_count = total_size // group_dim

        # Generate the ranks of every group.
        rank_groups = []
        for group_idx in range(group_count):
            decomposed_group = decompose_index(group_idx, unmasked_dims)
            current_group = []
            for in_group_idx in range(group_dim):
                decomposed_rank = decompose_index(in_group_idx, masked_dims)
                rank_value = (calculate_inner_product(decomposed_rank, masked_strides) +
                              calculate_inner_product(decomposed_group, unmasked_strides))
                current_group.append(rank_value)
            rank_groups.append(current_group)

        return rank_groups

    def generate_ranks(self, token: str, separate_ep: bool = False) -> List[List[int]]:
        """Return the rank groups of the dimension(s) named by ``token``.

        Args:
            token: Dash-separated dimension name(s) from the active order, e.g. "tp".
            separate_ep: If True, use the layout in which dp is split into
                (dp // ep, ep), so "ep" can be requested. Otherwise ep is ignored.

        Raises:
            MsprobeException: INVALID_PARAM_ERROR if ``separate_ep`` is True and dp is not
                divisible by ep. The ep-separated layout would then not cover every rank.
            ValueError: if ``token`` names a dimension absent from the active order.
        """
        if separate_ep:
            if self.data_parallel % self.expert_parallel != 0:
                logger.error(f"The data parallel size ({self.data_parallel}) is not divisible by "
                             f"expert parallel size ({self.expert_parallel}).")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            parallel_dims = self.size_list_with_ep
            current_order = self.order_with_ep
        else:
            parallel_dims = self.size_list_without_ep
            current_order = self.order_without_ep

        mask = self.create_mask(current_order, token)
        return self.create_masked_rank_groups(self.total_size, parallel_dims, mask)

    def generate_all_ranks(self) -> dict:
        """Return the dp/pp/tp groups (ep not separated) plus each dimension's size.

        The keys are "dp", "pp", "tp" (lists of rank groups) and "dp_size",
        "pp_size", "tp_size".
        """
        result = {}
        for token in ["dp", "pp", "tp"]:
            result[token] = self.generate_ranks(token)
            result[f"{token}_size"] = self.parallel_sizes[token]
        return result
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def get_tp_pp_default_groups(
        total_world_size: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        order: str = "tp-cp-ep-dp-pp",
) -> tuple:
    """Build the tensor-parallel and pipeline-parallel rank groups for ``total_world_size`` ranks.

    Context and expert parallelism are fixed to 1. The data-parallel size is what is
    left over: ``total_world_size // (tensor_parallel_size * pipeline_parallel_size)``.

    Args:
        total_world_size: Total number of ranks (>= 1).
        tensor_parallel_size: Tensor-parallel size (>= 1).
        pipeline_parallel_size: Pipeline-parallel size (>= 1).
        order: Dimension order forwarded to RankGroupGenerator.

    Returns:
        tuple: ``(tp_groups, pp_groups)``, each a list of rank lists.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR when a size is not positive or the world size
            is not divisible by the parallel product. RankGroupGenerator may also reject
            an invalid ``order``.
    """
    context_parallel_size = 1
    expert_parallel_size = 1

    # Reject non-positive sizes up front. A zero would otherwise escape as a
    # ZeroDivisionError below. Negative values would pass the divisibility check
    # and produce empty or meaningless groups.
    sizes_to_check = (
        ("world size", total_world_size),
        ("tensor parallel size", tensor_parallel_size),
        ("pipeline parallel size", pipeline_parallel_size),
    )
    for size_name, size_value in sizes_to_check:
        if size_value < 1:
            logger.error(f"The {size_name} ({size_value}) must be a positive integer.")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    # The world size must be divisible by the product of the parallel dimensions.
    product = tensor_parallel_size * pipeline_parallel_size * context_parallel_size
    if total_world_size % product != 0:
        logger.error(f"The world size ({total_world_size}) is not divisible by "
                     f"{tensor_parallel_size} x {pipeline_parallel_size} x {context_parallel_size}.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    data_parallel_size = total_world_size // product

    # The data-parallel size must be divisible by the expert-parallel size.
    if data_parallel_size % expert_parallel_size != 0:
        logger.error(f"The data parallel size ({data_parallel_size}) is not divisible by expert parallel size.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    # Generate the rank groups.
    rank_creator = RankGroupGenerator(
        tensor_parallel=tensor_parallel_size,
        expert_parallel=expert_parallel_size,
        data_parallel=data_parallel_size,
        pipeline_parallel=pipeline_parallel_size,
        context_parallel=context_parallel_size,
        order=order,
    )

    return rank_creator.generate_ranks('tp'), rank_creator.generate_ranks('pp')
|
msprobe/core/common/utils.py
CHANGED
|
@@ -28,7 +28,7 @@ import numpy as np
|
|
|
28
28
|
from msprobe.core.common.const import Const, CompareConst
|
|
29
29
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
30
30
|
from msprobe.core.common.exceptions import MsprobeException
|
|
31
|
-
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
31
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json, load_construct_json)
|
|
32
32
|
from msprobe.core.common.log import logger
|
|
33
33
|
|
|
34
34
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
@@ -82,6 +82,9 @@ class MsprobeBaseException(Exception):
|
|
|
82
82
|
INVALID_STATE_ERROR = 35
|
|
83
83
|
INVALID_API_NAME_ERROR = 36
|
|
84
84
|
CROSS_FRAME_ERROR = 37
|
|
85
|
+
MISSING_THRESHOLD_ERROR = 38
|
|
86
|
+
WRONG_THRESHOLD_ERROR = 39
|
|
87
|
+
MULTIPROCESS_ERROR = 40
|
|
85
88
|
|
|
86
89
|
def __init__(self, code, error_info: str = ""):
|
|
87
90
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -231,15 +234,6 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode):
|
|
|
231
234
|
_check_json(stack_json, input_param.get("stack_json_path"))
|
|
232
235
|
|
|
233
236
|
|
|
234
|
-
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
    """Ensure every compare-configuration switch is a real ``bool``.

    Arguments are checked in the order stack_mode, auto_analyze, fuzzy_match,
    is_print_compare_log.

    Raises:
        CompareException: INVALID_PARAM_ERROR for the first argument that is not a bool.
    """
    named_flags = (
        ('stack_mode', stack_mode),
        ('auto_analyze', auto_analyze),
        ('fuzzy_match', fuzzy_match),
        ('is_print_compare_log', is_print_compare_log),
    )
    for flag_name, flag_value in named_flags:
        if isinstance(flag_value, bool):
            continue
        logger.error(f"Invalid input parameter, {flag_name} which should be only bool type.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
241
|
-
|
|
242
|
-
|
|
243
237
|
def _check_json(json_file_handle, file_name):
|
|
244
238
|
tensor_line = json_file_handle.readline()
|
|
245
239
|
if not tensor_line:
|
|
@@ -283,6 +277,10 @@ def add_time_with_xlsx(name):
|
|
|
283
277
|
return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
|
|
284
278
|
|
|
285
279
|
|
|
280
|
+
def add_time_with_json(name):
    """Return ``name`` suffixed with a local ``YYYYmmddHHMMSS`` timestamp and a ``.json`` extension."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{timestamp}.json"
|
|
282
|
+
|
|
283
|
+
|
|
286
284
|
def add_time_with_yaml(name):
    """Return ``name`` suffixed with a local ``YYYYmmddHHMMSS`` timestamp and a ``.yaml`` extension."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{timestamp}.yaml"
|
|
288
286
|
|
|
@@ -351,8 +349,18 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
|
351
349
|
stack_json = os.path.join(directory, "stack.json")
|
|
352
350
|
construct_json = os.path.join(directory, "construct.json")
|
|
353
351
|
|
|
352
|
+
stack_json_exist = os.path.exists(stack_json)
|
|
353
|
+
construct_json_exist = os.path.exists(construct_json)
|
|
354
|
+
|
|
355
|
+
if not stack_json_exist and not construct_json_exist:
|
|
356
|
+
logger.info("stack.json and construct.json not found")
|
|
357
|
+
return {}, {}
|
|
358
|
+
if not stack_json_exist or not construct_json_exist:
|
|
359
|
+
logger.error("stack.json or construct.json not found, please check.")
|
|
360
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
361
|
+
|
|
354
362
|
stack = load_json(stack_json)
|
|
355
|
-
construct =
|
|
363
|
+
construct, _ = load_construct_json(construct_json)
|
|
356
364
|
return stack, construct
|
|
357
365
|
|
|
358
366
|
|
|
@@ -552,7 +560,7 @@ def check_token_range(token_range):
|
|
|
552
560
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
553
561
|
|
|
554
562
|
start, end = token_range
|
|
555
|
-
if not
|
|
563
|
+
if not is_int(start) or not is_int(end):
|
|
556
564
|
logger.error("Start and end in token_range must be integer.")
|
|
557
565
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
558
566
|
if start > end:
|
|
@@ -700,4 +708,3 @@ def check_process_num(process_num):
|
|
|
700
708
|
raise ValueError(f"process_num({process_num}) is not a positive integer")
|
|
701
709
|
if process_num > Const.MAX_PROCESS_NUM:
|
|
702
710
|
raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.")
|
|
703
|
-
|
msprobe/core/common_config.py
CHANGED
|
@@ -30,6 +30,7 @@ class CommonConfig:
|
|
|
30
30
|
self.level = json_config.get('level')
|
|
31
31
|
self.enable_dataloader = json_config.get('enable_dataloader', False)
|
|
32
32
|
self.async_dump = json_config.get("async_dump", False)
|
|
33
|
+
self.precision = json_config.get("precision", Const.DUMP_PRECISION_LOW)
|
|
33
34
|
self._check_config()
|
|
34
35
|
|
|
35
36
|
def _check_config(self):
|
|
@@ -51,6 +52,10 @@ class CommonConfig:
|
|
|
51
52
|
elif self.async_dump:
|
|
52
53
|
logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
|
|
53
54
|
|
|
55
|
+
if self.precision not in Const.DUMP_PRECISION_LIST:
|
|
56
|
+
logger.error_log_with_exp("precision is invalid, it should be one of {}".format(Const.DUMP_PRECISION_LIST),
|
|
57
|
+
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
58
|
+
|
|
54
59
|
|
|
55
60
|
class BaseConfig:
|
|
56
61
|
def __init__(self, json_config):
|