mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import wraps
17
+
18
+
19
class MegatronStepInfo:
    """Class-level record of Megatron micro-step progress.

    The class is used as a namespace: state is kept in class attributes that are read and
    mutated directly, never through instances.
    """

    # Flips to True the first time a wrapped Megatron step function is invoked.
    is_megatron = False
    # Direction of the most recently invoked wrapped step.
    is_forward = False
    is_backward = False
    # Per-direction call counters; -1 means no call has been recorded yet.
    forward_micro_step = -1
    backward_micro_step = -1

    @classmethod
    def reset(cls):
        """Reset all class attributes to their initial state."""
        cls.is_megatron = cls.is_forward = cls.is_backward = False
        cls.forward_micro_step = cls.backward_micro_step = -1
34
+
35
+
36
def wrap_megatron_step(func, is_forward=True):
    """Wrap ``func`` so that every call records a Megatron micro-step in MegatronStepInfo.

    Args:
        func: the step callable to wrap.
        is_forward: truthy to count calls as forward micro-steps, falsy for backward ones.

    Returns:
        A ``functools.wraps``-decorated wrapper that updates MegatronStepInfo and then
        returns ``func(*args, **kwargs)``.
    """
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        info = MegatronStepInfo
        if not info.is_megatron:
            info.is_megatron = True
        forward = bool(is_forward)
        # Exactly one direction flag is set; the matching counter advances by one.
        info.is_forward = forward
        info.is_backward = not forward
        if forward:
            info.forward_micro_step += 1
        else:
            info.backward_micro_step += 1
        return func(*args, **kwargs)

    return wrapped_func
52
+
53
+
54
def get_micro_step():
    """Return the micro-step counter for the direction of the latest wrapped step.

    Returns ``MegatronStepInfo.forward_micro_step`` when ``is_forward`` is set, otherwise
    ``MegatronStepInfo.backward_micro_step``.
    """
    if MegatronStepInfo.is_forward:
        return MegatronStepInfo.forward_micro_step
    return MegatronStepInfo.backward_micro_step
56
+
57
+
58
def is_megatron():
    """Return the current value of ``MegatronStepInfo.is_megatron``.

    The flag starts out False and is set to True by the wrapper that
    ``wrap_megatron_step`` builds; ``MegatronStepInfo.reset()`` clears it again.
    """
    flag = MegatronStepInfo.is_megatron
    return flag
@@ -0,0 +1,193 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List
17
+
18
+ from msprobe.core.common.log import logger
19
+ from msprobe.core.common.exceptions import MsprobeException
20
+
21
+
22
class RankGroupGenerator(object):
    """Build Megatron-style orthogonal rank groups for tp/pp/dp/ep/cp parallelism.

    A rank is treated as a mixed-radix number whose digits follow ``order`` (the first
    token has stride 1, i.e. varies fastest). The group for a token set contains every rank
    that differs only in those tokens' digits.
    """

    def __init__(self, tensor_parallel: int, expert_parallel: int, data_parallel: int,
                 pipeline_parallel: int, context_parallel: int, order: str) -> None:
        """Validate the parallel sizes against ``order`` and precompute per-token sizes.

        Args:
            tensor_parallel, expert_parallel, data_parallel, pipeline_parallel, context_parallel:
                size of each parallel dimension.
            order: '-'-joined tokens drawn from tp/pp/dp/ep/cp (case-insensitive). Tokens
                whose size is 1 may be omitted; they are appended at the end.

        Raises:
            MsprobeException: INVALID_PARAM_ERROR if ``order`` has unknown or repeated
                tokens, if ep and dp are not adjacent, if a dimension whose size is not 1 is
                missing from ``order``, or if data_parallel is not divisible by expert_parallel.
        """
        self.tensor_parallel = tensor_parallel
        self.expert_parallel = expert_parallel
        self.data_parallel = data_parallel
        self.pipeline_parallel = pipeline_parallel
        self.context_parallel = context_parallel
        # ep is not a separate factor of the world size: it subdivides dp (see dp // ep below).
        self.total_size = tensor_parallel * data_parallel * pipeline_parallel * context_parallel

        self.parallel_sizes = {
            "tp": self.tensor_parallel,
            "pp": self.pipeline_parallel,
            "dp": self.data_parallel,
            "ep": self.expert_parallel,
            "cp": self.context_parallel,
        }
        self.original_order = order
        normalized_order = order.lower()

        # Reject unknown or repeated tokens up front; otherwise they surface later as a bare
        # KeyError or as silently inconsistent size lists.
        order_tokens = normalized_order.split('-')
        has_unknown = any(token not in self.parallel_sizes for token in order_tokens)
        if has_unknown or len(set(order_tokens)) != len(order_tokens):
            logger.error(f"The order ({self.original_order}) must consist of distinct tokens "
                         f"from {list(self.parallel_sizes)} joined by '-'.")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        # ep and dp must be adjacent in the order.
        if 'ep' in normalized_order:
            if 'ep-dp' not in normalized_order and 'dp-ep' not in normalized_order:
                logger.error(f"The ep and dp must be adjacent in order ({self.original_order}).")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        # Every dimension whose size is not 1 must appear in order; omitted size-1 dimensions
        # are appended so that every dimension has a position.
        for name, size in self.parallel_sizes.items():
            if name in normalized_order:
                continue
            if size != 1:
                logger.error(f"The parallel size ({name}) is ({size}), "
                             f"but it's not specified in order ({self.original_order}).")
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            normalized_order += '-' + name

        # dp is split into (dp // ep) x ep below; a remainder (or ep < 1) would make that
        # layout disagree with total_size.
        if expert_parallel < 1 or data_parallel % expert_parallel != 0:
            logger.error(f"The data parallel size ({data_parallel}) is not divisible by "
                         f"expert parallel size ({expert_parallel}).")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

        self.order_with_ep = normalized_order
        self.order_without_ep = '-'.join([item for item in normalized_order.split('-') if item != 'ep'])

        # Per-token sizes in layout order. With ep laid out separately, dp contributes
        # dp // ep and ep contributes ep; without it, ep is dropped and dp keeps its full size.
        self.size_list_with_ep = []
        self.size_list_without_ep = []

        for item in normalized_order.split('-'):
            if item == 'dp':
                self.size_list_with_ep.append(self.data_parallel // self.expert_parallel)
                self.size_list_without_ep.append(self.data_parallel)
            elif item == 'ep':
                self.size_list_with_ep.append(self.expert_parallel)
            else:
                self.size_list_with_ep.append(self.parallel_sizes[item])
                self.size_list_without_ep.append(self.parallel_sizes[item])

    @staticmethod
    def create_mask(order_str: str, target_tokens: str) -> List[bool]:
        """Return one bool per token of ``order_str``, True where it is in ``target_tokens``.

        Both arguments are '-'-joined token strings. Raises ValueError (from list.index) if
        a target token does not occur in ``order_str``.
        """
        order_elements = order_str.split('-')
        target_elements = target_tokens.split('-')
        mask = [False] * len(order_elements)
        for token in target_elements:
            mask[order_elements.index(token)] = True
        return mask

    @staticmethod
    def create_masked_rank_groups(
            total_size: int,
            parallel_dims: List[int],
            mask: List[bool],
    ) -> List[List[int]]:
        """Split ``total_size`` ranks into groups that vary only along the masked dimensions.

        Args:
            total_size: number of ranks to partition.
            parallel_dims: dimension sizes in layout order (the first has stride 1).
            mask: True for dimensions that vary inside a group.

        Returns:
            ``total_size // prod(masked dims)`` groups. Within each group, ranks are listed
            by increasing in-group index (first masked dimension fastest).

        Raises:
            MsprobeException: INVALID_PARAM_ERROR if an index cannot be decomposed exactly
                by the given shape.
        """
        def compute_prefix_products(dimensions: List[int], initial: int = 1) -> List[int]:
            # [initial, initial*d0, initial*d0*d1, ...]; has len(dimensions) + 1 entries.
            products = [initial]
            current = initial
            for dim in dimensions:
                current *= dim
                products.append(current)
            return products

        def calculate_inner_product(a: List[int], b: List[int]) -> int:
            return sum(x * y for x, y in zip(a, b))

        def decompose_index(index: int, shape: List[int], strides: List[int] = None) -> List[int]:
            # Mixed-radix digits of `index` for `shape` (first dimension fastest).
            if strides is None:
                strides = compute_prefix_products(shape)
            indices = [(index // stride) % dim for dim, stride in zip(shape, strides)]

            # Verify that the digits reconstruct the index; the last stride is unused.
            if calculate_inner_product(indices, strides[:-1]) != index:
                error_msg = f"The index {index} with shape {shape} doesn't match decomposed indices {indices}."
                logger.error(error_msg)
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

            return indices

        # Separate masked (in-group) and unmasked (group-selecting) dimensions.
        masked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if is_masked]
        unmasked_dims = [dim for dim, is_masked in zip(parallel_dims, mask) if not is_masked]

        # Global strides, then split them the same way as the dimensions.
        global_strides = compute_prefix_products(parallel_dims)
        masked_strides = [stride for stride, is_masked in zip(global_strides, mask) if is_masked]
        unmasked_strides = [stride for stride, is_masked in zip(global_strides, mask) if not is_masked]

        # Group size and number of groups.
        group_dim = compute_prefix_products(masked_dims)[-1]
        group_count = total_size // group_dim

        # Enumerate every group's ranks.
        rank_groups = []
        for group_idx in range(group_count):
            decomposed_group = decompose_index(group_idx, unmasked_dims)
            current_group = []
            for in_group_idx in range(group_dim):
                decomposed_rank = decompose_index(in_group_idx, masked_dims)
                rank_value = (calculate_inner_product(decomposed_rank, masked_strides) +
                              calculate_inner_product(decomposed_group, unmasked_strides))
                current_group.append(rank_value)
            rank_groups.append(current_group)

        return rank_groups

    def generate_ranks(self, token: str, separate_ep: bool = False) -> List[List[int]]:
        """Return the rank groups for ``token`` (e.g. 'tp', 'pp' or 'tp-pp').

        Args:
            token: '-'-joined lowercase tokens that vary inside a group.
            separate_ep: use the layout in which ep is its own dimension (dp split into
                dp // ep x ep); otherwise ep is dropped from the layout, so asking for 'ep'
                raises ValueError from create_mask.
        """
        if separate_ep:
            parallel_dims = self.size_list_with_ep
            current_order = self.order_with_ep
        else:
            parallel_dims = self.size_list_without_ep
            current_order = self.order_without_ep

        mask = self.create_mask(current_order, token)
        return self.create_masked_rank_groups(self.total_size, parallel_dims, mask)

    def generate_all_ranks(self) -> dict:
        """Return {'dp', 'pp', 'tp'} rank groups plus matching '<token>_size' entries."""
        result = {}
        for token in ["dp", "pp", "tp"]:
            result[token] = self.generate_ranks(token)
            result[f"{token}_size"] = self.parallel_sizes[token]
        return result
158
+
159
+
160
def get_tp_pp_default_groups(
        total_world_size: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
        order: str = "tp-cp-ep-dp-pp",
) -> tuple:
    """Return the (tp_groups, pp_groups) rank groups for a tp x pp x dp layout.

    Context and expert parallelism are fixed at 1. The data-parallel size is derived as
    ``total_world_size // (tensor_parallel_size * pipeline_parallel_size)``.

    Args:
        total_world_size: total number of ranks.
        tensor_parallel_size: tensor-parallel size (default 1).
        pipeline_parallel_size: pipeline-parallel size (default 1).
        order: rank layout order handed to RankGroupGenerator.

    Returns:
        A tuple ``(tp_groups, pp_groups)``; each element is a list of rank lists.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR if tp or pp is smaller than 1, if the world
            size is not divisible by tp * pp, or if RankGroupGenerator rejects the layout.
    """
    context_parallel_size = 1
    expert_parallel_size = 1

    # Reject non-positive divisors up front: a zero size would raise a bare ZeroDivisionError
    # in the modulo below, and a negative size yields a negative data-parallel size.
    for name, size in (("tensor parallel", tensor_parallel_size),
                       ("pipeline parallel", pipeline_parallel_size)):
        if size < 1:
            logger.error(f"The {name} size ({size}) must be a positive integer.")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    # The world size must be divisible by the product of the parallel dimensions.
    product = tensor_parallel_size * pipeline_parallel_size * context_parallel_size
    if total_world_size % product != 0:
        logger.error(f"The world size ({total_world_size}) is not divisible by "
                     f"{tensor_parallel_size} x {pipeline_parallel_size} x {context_parallel_size}.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    data_parallel_size = total_world_size // product

    # dp must be divisible by ep (always satisfied while ep is fixed at 1).
    if data_parallel_size % expert_parallel_size != 0:
        logger.error(f"The data parallel size ({data_parallel_size}) is not divisible by expert parallel size.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    # Build the rank groups.
    rank_creator = RankGroupGenerator(
        tensor_parallel=tensor_parallel_size,
        expert_parallel=expert_parallel_size,
        data_parallel=data_parallel_size,
        pipeline_parallel=pipeline_parallel_size,
        context_parallel=context_parallel_size,
        order=order,
    )

    return rank_creator.generate_ranks('tp'), rank_creator.generate_ranks('pp')
@@ -28,7 +28,7 @@ import numpy as np
28
28
  from msprobe.core.common.const import Const, CompareConst
29
29
  from msprobe.core.common.decorator import recursion_depth_decorator
30
30
  from msprobe.core.common.exceptions import MsprobeException
31
- from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
31
+ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json, load_construct_json)
32
32
  from msprobe.core.common.log import logger
33
33
 
34
34
  device = collections.namedtuple('device', ['type', 'index'])
@@ -82,6 +82,9 @@ class MsprobeBaseException(Exception):
82
82
  INVALID_STATE_ERROR = 35
83
83
  INVALID_API_NAME_ERROR = 36
84
84
  CROSS_FRAME_ERROR = 37
85
+ MISSING_THRESHOLD_ERROR = 38
86
+ WRONG_THRESHOLD_ERROR = 39
87
+ MULTIPROCESS_ERROR = 40
85
88
 
86
89
  def __init__(self, code, error_info: str = ""):
87
90
  super(MsprobeBaseException, self).__init__()
@@ -231,15 +234,6 @@ def check_compare_param(input_param, output_path, dump_mode, stack_mode):
231
234
  _check_json(stack_json, input_param.get("stack_json_path"))
232
235
 
233
236
 
234
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
    """Ensure every compare configuration switch is a genuine bool.

    Raises:
        CompareException: INVALID_PARAM_ERROR for the first non-bool argument, after logging
            which parameter was wrong.
    """
    named_args = (
        ('stack_mode', stack_mode),
        ('auto_analyze', auto_analyze),
        ('fuzzy_match', fuzzy_match),
        ('is_print_compare_log', is_print_compare_log),
    )
    for name, arg in named_args:
        if isinstance(arg, bool):
            continue
        logger.error(f"Invalid input parameter, {name} which should be only bool type.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
241
-
242
-
243
237
  def _check_json(json_file_handle, file_name):
244
238
  tensor_line = json_file_handle.readline()
245
239
  if not tensor_line:
@@ -283,6 +277,10 @@ def add_time_with_xlsx(name):
283
277
  return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
284
278
 
285
279
 
280
def add_time_with_json(name):
    """Return ``name`` suffixed with a local ``_YYYYmmddHHMMSS`` timestamp and ``.json``."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{stamp}.json"
282
+
283
+
def add_time_with_yaml(name):
    """Return ``name`` suffixed with a local ``_YYYYmmddHHMMSS`` timestamp and ``.yaml``."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{stamp}.yaml"
288
286
 
@@ -351,8 +349,18 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
351
349
  stack_json = os.path.join(directory, "stack.json")
352
350
  construct_json = os.path.join(directory, "construct.json")
353
351
 
352
+ stack_json_exist = os.path.exists(stack_json)
353
+ construct_json_exist = os.path.exists(construct_json)
354
+
355
+ if not stack_json_exist and not construct_json_exist:
356
+ logger.info("stack.json and construct.json not found")
357
+ return {}, {}
358
+ if not stack_json_exist or not construct_json_exist:
359
+ logger.error("stack.json or construct.json not found, please check.")
360
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
361
+
354
362
  stack = load_json(stack_json)
355
- construct = load_json(construct_json)
363
+ construct, _ = load_construct_json(construct_json)
356
364
  return stack, construct
357
365
 
358
366
 
@@ -552,7 +560,7 @@ def check_token_range(token_range):
552
560
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
553
561
 
554
562
  start, end = token_range
555
- if not isinstance(start, int) or not isinstance(end, int):
563
+ if not is_int(start) or not is_int(end):
556
564
  logger.error("Start and end in token_range must be integer.")
557
565
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
558
566
  if start > end:
@@ -700,4 +708,3 @@ def check_process_num(process_num):
700
708
  raise ValueError(f"process_num({process_num}) is not a positive integer")
701
709
  if process_num > Const.MAX_PROCESS_NUM:
702
710
  raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.")
703
-
@@ -30,6 +30,7 @@ class CommonConfig:
30
30
  self.level = json_config.get('level')
31
31
  self.enable_dataloader = json_config.get('enable_dataloader', False)
32
32
  self.async_dump = json_config.get("async_dump", False)
33
+ self.precision = json_config.get("precision", Const.DUMP_PRECISION_LOW)
33
34
  self._check_config()
34
35
 
35
36
  def _check_config(self):
@@ -51,6 +52,10 @@ class CommonConfig:
51
52
  elif self.async_dump:
52
53
  logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
53
54
 
55
+ if self.precision not in Const.DUMP_PRECISION_LIST:
56
+ logger.error_log_with_exp("precision is invalid, it should be one of {}".format(Const.DUMP_PRECISION_LIST),
57
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
58
+
54
59
 
55
60
  class BaseConfig:
56
61
  def __init__(self, json_config):