mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,95 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import re
|
|
17
|
-
import abc
|
|
18
|
-
import torch
|
|
19
|
-
|
|
20
|
-
from msprobe.pytorch.common.log import logger
|
|
21
|
-
|
|
22
|
-
# Registry of every ConfigValidator implementation, keyed by class name.
config_validator_registry = {}


def register_config_validator(cls):
    """Class decorator that registers a ConfigValidator implementation.

    The class is stored in ``config_validator_registry`` under
    ``cls.__name__`` and returned unchanged, so the decorated name still
    refers to the original class. Registering a second class with the same
    name overwrites the earlier entry.
    """
    config_validator_registry[cls.__name__] = cls
    return cls
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class ConfigValidator(metaclass=abc.ABCMeta):
    """Abstract interface implemented by every config-spec validator.

    A validator first decides whether it recognises a spec string
    (``check_pattern_match``) and, if so, checks the actual data against
    that spec (``validate``).
    """

    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        """Return a truthy match result if this validator handles ``config_spec``.

        A falsy result means the spec is not recognised by this validator.
        The returned object is later passed to ``validate`` as
        ``pattern_match``.
        """
        pass

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Check ``actual_data`` against the spec captured in ``pattern_match``.

        ``module_name`` and ``data_type`` are only used to build error
        messages. Implementations raise ``ValueError`` when the data does
        not satisfy the spec.
        """
        pass
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@register_config_validator
class TensorValidator(ConfigValidator):
    """Validator for the ``tensor`` config spec."""

    # Compiled once at class creation instead of on every call.
    _PATTERN = re.compile(r"tensor")

    def check_pattern_match(self, config_spec: str):
        """Return a match object if ``config_spec`` starts with ``tensor``, else None.

        ``re.match`` anchors only at the start, so any spec that merely
        begins with ``tensor`` is accepted.
        """
        return self._PATTERN.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Raise ``ValueError`` unless ``actual_data`` is a torch tensor.

        Returns None on success. ``pattern_match`` is accepted for interface
        compatibility and is not used.
        """
        if not torch.is_tensor(actual_data):
            raise ValueError(
                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@register_config_validator
class TupleValidator(ConfigValidator):
    """Validator for the ``tuple[x]`` / ``tuple[x]:y`` config spec.

    ``x`` is the required tuple length and ``y`` (default 0) is the index of
    the element of interest.
    """

    # Group 1 is the tuple length, optional group 2 is the focused index.
    # Compiled once at class creation instead of on every call.
    _PATTERN = re.compile(r"tuple\[(\d+)\]:?(\d+)?")

    def check_pattern_match(self, config_spec: str):
        """Return a match object if ``config_spec`` starts with ``tuple[x]``, else None."""
        return self._PATTERN.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Check that ``actual_data`` is a tuple of the configured length.

        Returns the focused index ``y`` (0 when the spec omits it).

        Raises:
            ValueError: if ``y`` is not in ``[0, x)``, if ``actual_data`` is
                not a tuple, or if its length differs from ``x``.
        """
        length, index = pattern_match.groups()
        length = int(length)
        index = 0 if index is None else int(index)

        if not (0 <= index < length):
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
                "y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != length:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
                f"actual is {len(actual_data)} please check.")
        return index
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    """Validate ``actual_data`` against ``config_spec`` using the registered validators.

    The first registered validator whose ``check_pattern_match`` accepts the
    spec is used; later validators are not consulted.

    Returns:
        Whatever the matching validator's ``validate`` returns (the focused
        index for a tuple spec), or None when the spec is empty or not a
        string, when no validator recognises it, or when validation fails.
        Failures are logged as warnings rather than raised.
    """
    focused_col = None
    if not config_spec or not isinstance(config_spec, str):
        return focused_col
    for validator_cls in config_validator_registry.values():
        config_validator = validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if pattern_match:
            try:
                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
            except ValueError as e:
                logger.warning(f"config spec validate failed: {str(e)}")
            return focused_col
    # Raw f-string: the backslashes are part of the displayed pattern, not
    # escape sequences, so the emitted text is unchanged while the invalid
    # escape warning of the non-raw form is avoided.
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   rf"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
    return focused_col
|
|
@@ -1,160 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import argparse
|
|
17
|
-
import os
|
|
18
|
-
import re
|
|
19
|
-
from glob import glob
|
|
20
|
-
|
|
21
|
-
import pandas as pd
|
|
22
|
-
|
|
23
|
-
from msprobe.pytorch.common.log import logger
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def parse_logfile(logfile):
    """Extract the logged gradient norm from every training-log line.

    Only lines containing ``'consumed samples'`` are considered; the number
    following ``'grad norm: '`` on such a line is parsed as a float.

    Args:
        logfile: path to the training log file.

    Returns:
        List of gradient norms in file order.

    Raises:
        IndexError: if a ``'consumed samples'`` line has no ``'grad norm: '``.
        ValueError: if the text after ``'grad norm: '`` is not a valid float.
    """
    # Raw string so the backslashes reach the regex engine untouched;
    # compiled once outside the loop.
    grad_norm_pattern = re.compile(r'(?<=grad norm\: )[\d\.]*')
    grad_norm = []
    with open(logfile) as f:
        # Iterate lazily instead of loading the whole log with readlines().
        for line in f:
            if 'consumed samples' in line:
                grad_norm.append(float(grad_norm_pattern.findall(line)[0]))
    return grad_norm
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def parse_monitor_output(output_dir):
    """Load the monitor CSV output of every rank directory.

    Every path matching ``output_dir + '*'`` is treated as one rank's
    directory; the rank is the run of digits following the first ``rank`` in
    that path. Files whose name contains ``'_unreduced_'`` or ``'_reduced_'``
    are read with ``pandas.read_csv``; any other file is logged and skipped.

    Args:
        output_dir: path prefix of the monitor output directories.

    Returns:
        ``(reduced, unreduced)``: two dicts mapping rank -> list of DataFrames.

    Raises:
        IndexError / ValueError: if a matched path has no ``rank<digits>`` part.
    """
    rank_pattern = re.compile(r'(?<=rank)[\d]*')
    reduced = {}
    unreduced = {}
    for directory in glob(output_dir + '*'):
        rank = int(rank_pattern.findall(directory)[0])
        unreduced[rank] = []
        reduced[rank] = []
        # NOTE(review): os.listdir order is arbitrary, yet downstream code
        # indexes these lists by step -- confirm the intended ordering.
        for file in os.listdir(directory):
            # Classify by name first so unexpected files are never parsed.
            if '_unreduced_' in file:
                target = unreduced[rank]
            elif '_reduced_' in file:
                target = reduced[rank]
            else:
                logger.info(f'unexpected file {file} in {directory}')
                continue
            target.append(pd.read_csv(os.path.join(directory, file)))
    return reduced, unreduced
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
    """Cross-check per-parameter gradient means before and after reduction.

    For every parameter listed in rank 0's first unreduced DataFrame, and for
    the first two steps, the ``mean`` values are summed over all ranks for
    both the reduced and the unreduced data. The unreduced sum is divided by
    ``dp_size`` (and additionally by ``tp_size`` for parameters flagged as
    tp-duplicated) and compared with the reduced sum via ``assert_equal``.
    Mismatches are collected and logged; nothing is returned or raised for
    them.

    Args:
        reduced: dict rank -> list of per-step DataFrames of reduced grads.
        unreduced: dict rank -> list of per-step DataFrames of unreduced grads.
        tp_size: tensor-parallel size.
        dp_size: data-parallel size.
        sequence_parallel: whether sequence parallel is enabled.
    """
    world_size = len(reduced)
    errors = []
    for _, row in unreduced[0][0].iterrows():
        param = row['param_name']
        # Set at step 0 when some rank lacks this param in its reduced data;
        # deliberately kept for the following step as well.
        is_tp_duplicate = False
        for step in range(2):
            # sum reduced
            reduced_mean = 0.
            for rank in range(world_size):
                if len(reduced[rank]) == 0:
                    continue
                df = reduced[rank][step]
                value = list(df[df['param_name'] == param]['mean'])
                if not value:
                    if step == 0:
                        is_tp_duplicate = True
                    continue
                reduced_mean += value[0]

            # sum unreduced
            unreduced_mean = 0.
            for rank in range(world_size):
                df = unreduced[rank][step]
                value = list(df[df['param_name'] == param]['mean'])
                if not value:
                    continue
                # Reuse the lookup above instead of filtering the frame again.
                unreduced_mean += value[0]

            unreduced_mean /= dp_size
            if is_tp_duplicate and (not sequence_parallel or 'embedding' in param):
                unreduced_mean /= tp_size
            try:
                assert_equal(unreduced_mean, reduced_mean)
            except AssertionError as e:
                errors.append([param, step, e, is_tp_duplicate])
    if errors:
        logger.info(errors)
    else:
        logger.info('grad mean is in consist between unreduced grad and reduced grad monitord.')
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
def assert_equal(a, b):
    """Assert that ``a`` and ``b`` agree to within 1% relative difference.

    If either value is exactly zero the comparison is skipped and the
    function returns without checking anything.

    Raises:
        AssertionError: if ``abs(a / b - 1)`` is not below 0.01. The message
            is ``'<a>, <b>, <rel_diff>'``.
    """
    if b == 0 or a == 0:
        return
    rel_diff = abs(a / b - 1)
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    # `not (x < y)` rather than `x >= y` so a NaN difference still fails.
    if not (rel_diff < 0.01):
        raise AssertionError(f'{a}, {b}, {rel_diff}')
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def valid_total_norm(total_norm, reduced, duplicate_embedding):
    """Cross-check the logged total grad norm against the monitored norms.

    For every step, the squared ``norm`` column of each rank's reduced
    DataFrame is summed, and the square root of that sum is compared with
    ``total_norm[step]`` via ``assert_equal``. Mismatches are collected and
    logged; nothing is returned or raised for them.

    Args:
        total_norm: per-step total grad norms parsed from the training log.
        reduced: dict rank -> list of per-step DataFrames of reduced grads.
        duplicate_embedding: when true, rows whose ``param_name`` contains
            ``'word_embedding'`` are excluded from the sum.
    """
    steps = len(total_norm)
    world_size = len(reduced)
    errors = []
    for step in range(steps):
        calculated_norm = 0.
        for rank in range(world_size):
            if len(reduced[rank]) == 0:
                # Report ranks without reduced data once, not on every step.
                if step == 0:
                    logger.info(f'rank {rank} is duplicated in dp group')
                continue
            for _, row in reduced[rank][step].iterrows():
                if duplicate_embedding and 'word_embedding' in row['param_name']:
                    continue
                calculated_norm += row['norm'] ** 2
        try:
            assert_equal(calculated_norm ** 0.5, total_norm[step])
        except AssertionError as e:
            errors.append([step, e])
    if errors:
        # Interpolate the list into the message; passing it as a second
        # positional argument never rendered it.
        logger.info(f'total norm errors: {errors}')
    else:
        logger.info('grad norm in consist between training log and reduced gradients monitored')
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
# Command-line entry point: compare a training log with the monitor output.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--monitor_output', '-m', type=str, required=True,
                        help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16')
    parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file')
    parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size')
    parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size')
    parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size')
    parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False,
                        help='whether untie_embeddings_and_output_weights in pp parallel')
    parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False,
                        help='whether sequence parallel is enabled. Add -s to store true')

    args = parser.parse_args()

    # Validate with parser.error rather than `assert`, which is stripped
    # under `python -O`; each message names the option it belongs to.
    if args.tp_size <= 0:
        parser.error('if tp not enabled, set tp_size = 1')
    if args.dp_size <= 0:
        parser.error('if dp not enabled, set dp_size = 1')
    if args.pp_size <= 0:
        parser.error('if pp not enabled, set pp_size = 1')

    total_norm = parse_logfile(args.logfile)
    reduced, unreduced = parse_monitor_output(args.monitor_output)

    # The embedding rows are skipped in the norm check when embeddings are
    # tied and pipeline parallelism is in use.
    duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1

    valid_total_norm(total_norm, reduced, duplicate_embedding)
    valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel)
|