mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -12,16 +12,12 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
import itertools
|
|
16
|
-
import math
|
|
17
15
|
import re
|
|
18
|
-
import statistics
|
|
19
16
|
|
|
20
17
|
import torch
|
|
21
18
|
|
|
22
|
-
from msprobe.
|
|
23
|
-
from msprobe.pytorch.monitor.
|
|
24
|
-
from msprobe.core.common.log import logger
|
|
19
|
+
from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
|
|
20
|
+
from msprobe.pytorch.monitor.utils import get_nan_tensor
|
|
25
21
|
|
|
26
22
|
|
|
27
23
|
def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
|
|
@@ -31,7 +27,9 @@ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
|
|
|
31
27
|
return f"{module_or_param_name}/rank{rank}/{tag}"
|
|
32
28
|
|
|
33
29
|
|
|
34
|
-
def squash_param_name(param_name):
|
|
30
|
+
def squash_param_name(param_name, enable=True):
|
|
31
|
+
if not enable:
|
|
32
|
+
return param_name
|
|
35
33
|
name = ''
|
|
36
34
|
for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
|
|
37
35
|
match = re.findall(pattern, param_name)
|
|
@@ -63,7 +61,7 @@ class TensorMetrics:
|
|
|
63
61
|
self.metrics = {} # tensor_tag --> []
|
|
64
62
|
self.cur_idx = {}
|
|
65
63
|
|
|
66
|
-
def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank
|
|
64
|
+
def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
|
|
67
65
|
"""get stats and insert into metrics dictionary"""
|
|
68
66
|
prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
|
|
69
67
|
for stat_op in stat_ops:
|
|
@@ -120,14 +118,14 @@ class NormMetric(Metric):
|
|
|
120
118
|
@staticmethod
|
|
121
119
|
def get_metric_value(tensor, eps):
|
|
122
120
|
return get_norm(tensor)
|
|
123
|
-
|
|
121
|
+
|
|
124
122
|
|
|
125
123
|
@register_config_metric("zeros")
|
|
126
124
|
class ZerosMetric(Metric):
|
|
127
125
|
@staticmethod
|
|
128
126
|
def get_metric_value(tensor, eps):
|
|
129
127
|
return get_zeros(tensor, eps)
|
|
130
|
-
|
|
128
|
+
|
|
131
129
|
|
|
132
130
|
@register_config_metric("nans")
|
|
133
131
|
class NaNsMetric(Metric):
|
|
@@ -146,48 +144,29 @@ class IdentMetric(Metric):
|
|
|
146
144
|
|
|
147
145
|
|
|
148
146
|
def get_metrics(ops, tag2tensor, eps, out_dict=None):
|
|
147
|
+
"""
|
|
148
|
+
:param ops: ["op1", "op2"]
|
|
149
|
+
:param tag2tensor: {
|
|
150
|
+
'0:fc.input:0/actv': torch.randn([3, 4]),
|
|
151
|
+
'0:fc.output:0/actv': torch.randn([3, 3])
|
|
152
|
+
}
|
|
153
|
+
:param eps: float 1e-8
|
|
154
|
+
:param out_dict:{
|
|
155
|
+
'0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
|
|
156
|
+
'0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
|
|
157
|
+
}
|
|
158
|
+
:return: out_dict
|
|
159
|
+
"""
|
|
149
160
|
if out_dict is None:
|
|
150
161
|
out_dict = {}
|
|
151
162
|
for tag, tensor in tag2tensor.items():
|
|
152
163
|
if tag not in out_dict:
|
|
153
164
|
out_dict[tag] = {}
|
|
154
|
-
|
|
165
|
+
if not torch.is_tensor(tensor):
|
|
166
|
+
# Non-tensor in/output filled with nan.
|
|
167
|
+
out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
|
|
168
|
+
continue
|
|
169
|
+
for metric_name in ops:
|
|
155
170
|
fun_metric = config_metric_registry.get(metric_name)
|
|
156
171
|
out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
|
|
157
172
|
return out_dict
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
|
|
161
|
-
if not metric_value:
|
|
162
|
-
return
|
|
163
|
-
tensors = []
|
|
164
|
-
tags = list(itertools.product(metric_value.keys(), ops))
|
|
165
|
-
for op2tensor in metric_value.values():
|
|
166
|
-
tensors.extend(op2tensor.values())
|
|
167
|
-
with torch.no_grad():
|
|
168
|
-
metric_list = torch.stack(tensors).cpu()
|
|
169
|
-
for tag, metric in zip(tags, metric_list):
|
|
170
|
-
summary_writer.add_scalar(tag, metric, step)
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
|
|
174
|
-
write_metrics_base(ops, summary_writer, metric_value, step, prefix='')
|
|
175
|
-
|
|
176
|
-
if not summary_writer.header:
|
|
177
|
-
# 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_
|
|
178
|
-
if prefix in {"actv", "actv_grad"}:
|
|
179
|
-
if prefix == "actv":
|
|
180
|
-
input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
|
|
181
|
-
else:
|
|
182
|
-
input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
|
|
183
|
-
ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
|
|
184
|
-
summary_writer.header = ["module_name", "step", *ops_]
|
|
185
|
-
else:
|
|
186
|
-
summary_writer.header = ["param_name", "step", *ops]
|
|
187
|
-
|
|
188
|
-
for key in metric_value.keys():
|
|
189
|
-
if MonitorConst.VPP_SEP in key:
|
|
190
|
-
summary_writer.header.insert(0, 'vpp_stage')
|
|
191
|
-
break
|
|
192
|
-
summary_writer.write_csv(prefix, step)
|
|
193
|
-
summary_writer.header = []
|
|
@@ -17,7 +17,7 @@ import re
|
|
|
17
17
|
import abc
|
|
18
18
|
import torch
|
|
19
19
|
|
|
20
|
-
from msprobe.
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
21
|
|
|
22
22
|
# 用于存储所有validator实现类的注册表
|
|
23
23
|
config_validator_registry = {}
|
|
@@ -79,6 +79,8 @@ class TupleValidator(ConfigValidator):
|
|
|
79
79
|
|
|
80
80
|
def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
|
|
81
81
|
focused_col = None
|
|
82
|
+
if not config_spec or not isinstance(config_spec, str):
|
|
83
|
+
return focused_col
|
|
82
84
|
for _, validator_cls in config_validator_registry.items():
|
|
83
85
|
config_validator = validator_cls()
|
|
84
86
|
pattern_match = config_validator.check_pattern_match(config_spec)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,27 +13,20 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from abc import ABC, abstractmethod
|
|
17
16
|
from collections import defaultdict
|
|
18
17
|
|
|
19
18
|
import torch
|
|
20
19
|
import torch.distributed as dist
|
|
21
20
|
|
|
22
|
-
from msprobe.
|
|
21
|
+
from msprobe.pytorch.common.log import logger
|
|
23
22
|
from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
class OptimizerMon(object):
|
|
27
|
-
wrapped_optimizer = None
|
|
28
|
-
|
|
29
26
|
def __init__(self) -> None:
|
|
30
27
|
self.fp16_to_fp32_param = {}
|
|
31
28
|
self.is_stage3 = False
|
|
32
29
|
|
|
33
|
-
@classmethod
|
|
34
|
-
def set_wrapped_optimizer(cls, wrapped_optimizer):
|
|
35
|
-
cls.wrapped_optimizer = wrapped_optimizer
|
|
36
|
-
|
|
37
30
|
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
38
31
|
pass
|
|
39
32
|
|
|
@@ -83,11 +76,10 @@ class OptimizerMon(object):
|
|
|
83
76
|
ratio_dict = defaultdict()
|
|
84
77
|
param2name = defaultdict()
|
|
85
78
|
fp32_partitioned_groups_flat_grad = defaultdict()
|
|
86
|
-
mix_prec_opt = OptimizerMon.wrapped_optimizer
|
|
87
79
|
partition_id = dist.get_rank()
|
|
88
80
|
|
|
89
81
|
def get_flatten_grad(self, optimizer, group_idx):
|
|
90
|
-
if
|
|
82
|
+
if fp32_partitioned_groups_flat[group_idx].grad is None:
|
|
91
83
|
if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
|
|
92
84
|
fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
|
|
93
85
|
optimizer.averaged_gradients[group_idx],
|
|
@@ -102,7 +94,7 @@ class OptimizerMon(object):
|
|
|
102
94
|
return fp32_partitioned_groups_flat[group_idx].grad
|
|
103
95
|
|
|
104
96
|
for group_idx in range(len(fp32_partitioned_groups_flat)):
|
|
105
|
-
fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self,
|
|
97
|
+
fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
|
|
106
98
|
|
|
107
99
|
for name in params2name.values():
|
|
108
100
|
start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
|
|
@@ -111,9 +103,9 @@ class OptimizerMon(object):
|
|
|
111
103
|
fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
|
|
112
104
|
fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
|
|
113
105
|
param2name[fp32_param] = name
|
|
114
|
-
if not
|
|
106
|
+
if not torch_opt.state:
|
|
115
107
|
continue
|
|
116
|
-
state_param = list(
|
|
108
|
+
state_param = list(torch_opt.state.values())[group_idx]
|
|
117
109
|
exp_avg = state_param.get("exp_avg", None)
|
|
118
110
|
exp_avg_sq = state_param.get("exp_avg_sq", None)
|
|
119
111
|
if exp_avg is None or exp_avg_sq is None:
|
|
@@ -151,29 +143,33 @@ class MixPrecisionOptimizerMon(OptimizerMon):
|
|
|
151
143
|
混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
|
|
152
144
|
"""
|
|
153
145
|
|
|
146
|
+
def map_fp16_tp_fp32_param(self, torch_opt):
|
|
147
|
+
for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
|
|
148
|
+
for fp16_param, fp32_param in zip(fp16_group, fp32_group):
|
|
149
|
+
self.fp16_to_fp32_param[fp16_param] = fp32_param
|
|
150
|
+
|
|
154
151
|
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
155
|
-
|
|
152
|
+
if not self.fp16_to_fp32_param and torch_opt is not None:
|
|
153
|
+
self.map_fp16_tp_fp32_param(torch_opt)
|
|
156
154
|
|
|
157
|
-
if not self.fp16_to_fp32_param and mix_prec_opt is not None:
|
|
158
|
-
for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
|
|
159
|
-
for fp16_param, fp32_param in zip(fp16_group, fp32_group):
|
|
160
|
-
self.fp16_to_fp32_param[fp16_param] = fp32_param
|
|
161
155
|
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
162
156
|
|
|
163
157
|
|
|
164
158
|
class MegatronDistributedOptimizerMon(OptimizerMon):
|
|
165
|
-
def
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
|
|
159
|
+
def map_fp16_tp_fp32_param(self, torch_opt):
|
|
160
|
+
if not (hasattr(torch_opt, "model_float16_groups") and
|
|
161
|
+
hasattr(torch_opt, "shard_fp32_from_float16_groups")):
|
|
169
162
|
raise Exception(
|
|
170
163
|
"megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
|
|
171
164
|
"if not, please check megatron-lm version")
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
165
|
+
for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups,
|
|
166
|
+
torch_opt.shard_fp32_from_float16_groups):
|
|
167
|
+
for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
|
|
168
|
+
self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
|
|
169
|
+
|
|
170
|
+
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
171
|
+
if not self.fp16_to_fp32_param and torch_opt is not None:
|
|
172
|
+
self.map_fp16_tp_fp32_param(torch_opt)
|
|
177
173
|
|
|
178
174
|
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
179
175
|
|
|
@@ -183,15 +179,40 @@ class MegatronFP32OptimizerMon(OptimizerMon):
|
|
|
183
179
|
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
184
180
|
|
|
185
181
|
|
|
182
|
+
class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
|
|
183
|
+
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
184
|
+
if not self.fp16_to_fp32_param and torch_opt is not None:
|
|
185
|
+
for opt in torch_opt.chained_optimizers:
|
|
186
|
+
self.map_fp16_tp_fp32_param(opt)
|
|
187
|
+
|
|
188
|
+
if not isinstance(torch_opt, torch.optim.Optimizer):
|
|
189
|
+
torch_opt.state = {}
|
|
190
|
+
for opt in torch_opt.chained_optimizers:
|
|
191
|
+
torch_opt.state.update(opt.optimizer.state)
|
|
192
|
+
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
|
|
196
|
+
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
197
|
+
if not self.fp16_to_fp32_param and torch_opt is not None:
|
|
198
|
+
for opt in torch_opt.chained_optimizers:
|
|
199
|
+
self.map_fp16_tp_fp32_param(opt)
|
|
200
|
+
|
|
201
|
+
if not isinstance(torch_opt, torch.optim.Optimizer):
|
|
202
|
+
torch_opt.state = {}
|
|
203
|
+
for opt in torch_opt.chained_optimizers:
|
|
204
|
+
torch_opt.state.update(opt.optimizer.state)
|
|
205
|
+
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
206
|
+
|
|
207
|
+
|
|
186
208
|
class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
|
|
187
209
|
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
188
210
|
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
189
211
|
|
|
190
212
|
|
|
191
213
|
class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
|
|
192
|
-
def get_param_index(self, params2name, name2index):
|
|
193
|
-
|
|
194
|
-
fp16_groups = mix_prec_opt.fp16_partitioned_groups
|
|
214
|
+
def get_param_index(self, params2name, name2index, torch_opt):
|
|
215
|
+
fp16_groups = torch_opt.fp16_partitioned_groups
|
|
195
216
|
name2indices = defaultdict()
|
|
196
217
|
index_length = defaultdict()
|
|
197
218
|
index = 0
|
|
@@ -210,13 +231,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
|
|
|
210
231
|
|
|
211
232
|
def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
|
|
212
233
|
self.is_stage3 = True
|
|
213
|
-
|
|
214
|
-
fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
|
|
234
|
+
fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
|
|
215
235
|
return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
|
|
216
236
|
|
|
217
237
|
|
|
218
238
|
class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
|
|
219
|
-
|
|
220
239
|
@staticmethod
|
|
221
240
|
def get_group_index(fp32_length, world_size, index):
|
|
222
241
|
for i in range(len(fp32_length) - 1):
|
|
@@ -229,12 +248,11 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
|
|
|
229
248
|
return sub_interval_start, min(sub_index, world_size - 1)
|
|
230
249
|
return fp32_length[-1], 0
|
|
231
250
|
|
|
232
|
-
def get_param_index(self, params2name, name2index):
|
|
233
|
-
|
|
234
|
-
padding = mix_prec_opt.groups_padding
|
|
251
|
+
def get_param_index(self, params2name, name2index, torch_opt):
|
|
252
|
+
padding = torch_opt.groups_padding
|
|
235
253
|
world_size = dist.get_world_size()
|
|
236
254
|
fp32_length = [0]
|
|
237
|
-
for fp32_group_index, single_partition_of_fp32_group in enumerate(
|
|
255
|
+
for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
|
|
238
256
|
fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
|
|
239
257
|
|
|
240
258
|
bf16_groups = []
|
|
@@ -242,7 +260,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
|
|
|
242
260
|
index_length = defaultdict()
|
|
243
261
|
index = 0
|
|
244
262
|
idx = 0
|
|
245
|
-
for group_idx, bf16_group in enumerate(
|
|
263
|
+
for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
|
|
246
264
|
bf16_groups.extend(bf16_group)
|
|
247
265
|
for param in bf16_group:
|
|
248
266
|
param_length = len(param.flatten())
|
|
@@ -250,7 +268,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
|
|
|
250
268
|
index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
|
|
251
269
|
index += param_length
|
|
252
270
|
idx += 1
|
|
253
|
-
group_length = len(bf16_groups) / len(
|
|
271
|
+
group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
|
|
254
272
|
for _, name in params2name.items():
|
|
255
273
|
name_index = name2index[name]
|
|
256
274
|
start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
|
|
@@ -264,32 +282,34 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
|
|
|
264
282
|
return name2indices
|
|
265
283
|
|
|
266
284
|
def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
|
|
267
|
-
|
|
268
|
-
fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
|
|
285
|
+
fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
|
|
269
286
|
return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
|
|
270
287
|
|
|
271
288
|
|
|
272
289
|
class DummyOptimizerMon(OptimizerMon):
|
|
273
290
|
def fetch_mv(self, monitor, torch_opt, params2name):
|
|
274
|
-
return
|
|
291
|
+
return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
|
|
275
292
|
|
|
276
293
|
|
|
277
294
|
class OptimizerMonFactory:
|
|
278
295
|
_optimizer_mon_map = {
|
|
279
|
-
"
|
|
280
|
-
"
|
|
281
|
-
"
|
|
282
|
-
"
|
|
283
|
-
"
|
|
296
|
+
"FP32Optimizer": MegatronFP32OptimizerMon,
|
|
297
|
+
"Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
|
|
298
|
+
"DistributedOptimizer": MegatronDistributedOptimizerMon,
|
|
299
|
+
"ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
|
|
300
|
+
"ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
|
|
301
|
+
"BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
|
|
302
|
+
"DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
|
|
284
303
|
"DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
|
|
285
|
-
"
|
|
304
|
+
"Adam": DummyOptimizerMon
|
|
286
305
|
}
|
|
287
306
|
|
|
288
307
|
@staticmethod
|
|
289
|
-
def create_optimizer_mon(
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
308
|
+
def create_optimizer_mon(optimizer):
|
|
309
|
+
# auto replace opt_ty
|
|
310
|
+
optimizer_class = optimizer.__class__.__name__
|
|
311
|
+
if optimizer_class == "ChainedOptimizer":
|
|
312
|
+
optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
|
|
313
|
+
|
|
314
|
+
optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
|
|
315
|
+
return optimizer_mon_class(), optimizer_class
|
|
@@ -1,11 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
1
17
|
import os
|
|
2
18
|
import re
|
|
3
|
-
import argparse
|
|
4
19
|
from glob import glob
|
|
5
20
|
|
|
6
21
|
import pandas as pd
|
|
7
22
|
|
|
8
|
-
from msprobe.
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
9
24
|
|
|
10
25
|
|
|
11
26
|
def parse_logfile(logfile):
|
|
@@ -21,19 +36,19 @@ def parse_logfile(logfile):
|
|
|
21
36
|
def parse_monitor_output(output_dir):
    """Load per-rank monitor CSV files into DataFrames.

    Every path matching ``output_dir + '*'`` is treated as one rank's output
    directory; the rank id is the digit run that follows ``rank`` in the path.

    Args:
        output_dir: Path prefix shared by the per-rank output directories.

    Returns:
        tuple: ``(reduced, unreduced)``, two dicts mapping rank (int) to a
        list of DataFrames read from files whose names contain ``_reduced_``
        and ``_unreduced_`` respectively.
    """
    reduced = {}
    unreduced = {}
    for directory in glob(output_dir + '*'):
        # Require at least one digit so that a "rank" substring elsewhere in
        # the path (e.g. ".../frank_runs/...") cannot produce an empty match.
        rank_match = re.search(r'(?<=rank)\d+', directory)
        if rank_match is None:
            logger.info(f'no rank id found in {directory}, skipped')
            continue
        rank = int(rank_match.group())
        unreduced[rank] = []
        reduced[rank] = []
        # Sorted so the list order is reproducible; note the order is
        # lexicographic by file name, not numeric.
        for file in sorted(os.listdir(directory)):
            # '_unreduced_' must be tested first: both markers end in 'reduced_'.
            if '_unreduced_' in file:
                bucket = unreduced[rank]
            elif '_reduced_' in file:
                bucket = reduced[rank]
            else:
                # Unrecognised files are reported without being parsed, so a
                # stray non-CSV file cannot abort the whole run.
                logger.info(f'unexpected file {file} in {directory}')
                continue
            bucket.append(pd.read_csv(os.path.join(directory, file)))
    return reduced, unreduced
|
|
38
53
|
|
|
39
54
|
|
|
@@ -41,7 +56,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
|
|
|
41
56
|
steps = len(reduced[0])
|
|
42
57
|
world_size = len(reduced)
|
|
43
58
|
errors = []
|
|
44
|
-
for
|
|
59
|
+
for _, row in unreduced[0][0].iterrows():
|
|
45
60
|
param = row['param_name']
|
|
46
61
|
is_tp_duplicate = False
|
|
47
62
|
for step in range(2):
|
|
@@ -103,7 +118,7 @@ def valid_total_norm(total_norm, reduced, duplicate_embedding):
|
|
|
103
118
|
if step == 0:
|
|
104
119
|
logger.info(f'rank {rank} is duplicated in dp group')
|
|
105
120
|
continue
|
|
106
|
-
for
|
|
121
|
+
for _, row in reduced[rank][step].iterrows():
|
|
107
122
|
if duplicate_embedding and 'word_embedding' in row['param_name']:
|
|
108
123
|
continue
|
|
109
124
|
calculated_norm += row['norm'] ** 2
|
msprobe/pytorch/monitor/utils.py
CHANGED
|
@@ -16,13 +16,27 @@ import inspect
|
|
|
16
16
|
from collections import namedtuple
|
|
17
17
|
from datetime import timezone, timedelta
|
|
18
18
|
from functools import wraps
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
import os
|
|
21
|
+
import re
|
|
19
22
|
|
|
20
23
|
import torch
|
|
21
24
|
|
|
22
25
|
from msprobe.core.common.const import MonitorConst, Const
|
|
23
|
-
from msprobe.
|
|
26
|
+
from msprobe.pytorch.common.log import logger
|
|
24
27
|
from msprobe.core.common.utils import is_int
|
|
28
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
25
29
|
|
|
30
|
+
|
|
31
|
+
# Pick the device string once at import time: "npu" when torch_npu can be
# imported, otherwise "cuda" when CUDA is available, otherwise "cpu".
try:
    import torch_npu
except ImportError:
    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = "npu"

# Cache slot for the NaN tensor; filled on first use by get_nan_tensor().
NAN_TENSOR_ON_DEVICE = None
|
|
26
40
|
FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
|
|
27
41
|
FILE_NAME_MAX_LENGTH = 255
|
|
28
42
|
DIRECTORY_MAX_LENGTH = 4096
|
|
@@ -39,6 +53,17 @@ class MsgConst:
|
|
|
39
53
|
SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
|
|
40
54
|
|
|
41
55
|
|
|
56
|
+
def get_output_base_dir():
    """Return the monitor output base directory.

    The value comes from the environment variable named by
    ``MonitorConst.MONITOR_OUTPUT_DIR``; when that variable is unset,
    ``MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR`` is returned instead.
    """
    env_name = MonitorConst.MONITOR_OUTPUT_DIR
    fallback_dir = MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR
    return os.environ.get(env_name, fallback_dir)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_nan_tensor():
    """Return a cached scalar NaN tensor placed on the module-level ``device``.

    The tensor is created on the first call and stored in the module-level
    ``NAN_TENSOR_ON_DEVICE``; later calls return that same object.

    Returns:
        torch.Tensor: The shared NaN tensor.
    """
    global NAN_TENSOR_ON_DEVICE
    # Test for None explicitly. `not tensor` asks the tensor for its truth
    # value on every call, which needs the tensor's value on the host and so
    # adds a device read to what should be a plain cache lookup.
    if NAN_TENSOR_ON_DEVICE is None:
        NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
    return NAN_TENSOR_ON_DEVICE
|
|
65
|
+
|
|
66
|
+
|
|
42
67
|
def filter_special_chars(func):
|
|
43
68
|
@wraps(func)
|
|
44
69
|
def func_level(msg):
|
|
@@ -64,60 +89,19 @@ def get_param_struct(param):
|
|
|
64
89
|
return res
|
|
65
90
|
|
|
66
91
|
|
|
67
|
-
def is_recomputation():
    """Report whether the caller is running inside a re-computation pass.

    The current call stack is inspected for frames that are known to appear
    only during re-computation (a blacklist covering the Megatron and
    MindSpeed frameworks):

    * a ``backward`` frame located in ``torch/_tensor.py``;
    * a ``backward`` or ``checkpoint_function_backward`` frame whose caller,
      or caller's caller, is located in ``torch/autograd/function.py``.

    Returns:
        bool: True if one of those patterns is on the stack, False otherwise.
    """
    autograd_function_file = 'torch/autograd/function.py'
    frames = inspect.stack()
    try:
        # Pattern 1: 'backward' executing from within torch/_tensor.py.
        for frame in frames:
            if frame.function == Const.BACKWARD and frame.filename.endswith('torch/_tensor.py'):
                return True

        # Pattern 2: a backward-like frame invoked from torch/autograd/function.py.
        # Offset 1 covers the Megatron and MindSpeed L0&L1 scenes; offset 2
        # covers the latest MindSpeed L2 and ModelLink scenes.
        depth = len(frames)
        for position, frame in enumerate(frames):
            if frame.function not in (Const.BACKWARD, 'checkpoint_function_backward'):
                continue
            for offset in (1, 2):
                caller_position = position + offset
                if caller_position < depth and frames[caller_position].filename.endswith(autograd_function_file):
                    return True
        return False
    finally:
        # Drop the frame records explicitly so they do not linger.
        del frames
|
|
107
|
-
|
|
108
|
-
|
|
109
92
|
def validate_ops(ops):
    """Keep only the entries of ``ops`` that appear in ``MonitorConst.OP_LIST``.

    Args:
        ops: List of op names to check.

    Returns:
        list: The supported entries in their input order. When none are
        supported, a one-element list holding ``MonitorConst.OP_LIST[0]``.

    Raises:
        TypeError: If ``ops`` is not a list.
    """
    if not isinstance(ops, list):
        raise TypeError("ops should be a list")

    valid_ops = []
    for op in ops:
        if op in MonitorConst.OP_LIST:
            valid_ops.append(op)
        else:
            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")

    if valid_ops:
        return valid_ops

    # Nothing usable was requested: fall back to the first supported op.
    default_op = MonitorConst.OP_LIST[0]
    logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
    return [default_op]
|
|
122
106
|
|
|
123
107
|
|
|
@@ -164,6 +148,11 @@ def validate_mg_distribution(mg_distribution):
|
|
|
164
148
|
raise TypeError('mg_distribution should be a bool')
|
|
165
149
|
|
|
166
150
|
|
|
151
|
+
def validate_param_distribution(param_distribution):
    """Check that ``param_distribution`` is a bool.

    Raises:
        TypeError: If the value is not a ``bool`` instance.
    """
    if isinstance(param_distribution, bool):
        return
    raise TypeError('param_distribution should be a bool')
|
|
154
|
+
|
|
155
|
+
|
|
167
156
|
def validate_cc_distribution(cc_distribution):
|
|
168
157
|
if not isinstance(cc_distribution, dict):
|
|
169
158
|
raise TypeError('cc_distribution should be a dictionary')
|
|
@@ -184,6 +173,11 @@ def validate_cc_distribution(cc_distribution):
|
|
|
184
173
|
raise TypeError(f'{key} of cc_distribution is not supported.')
|
|
185
174
|
|
|
186
175
|
|
|
176
|
+
def validate_squash_name(squash_name):
    """Check that ``squash_name`` is a bool.

    Raises:
        TypeError: If the value is not a ``bool`` instance.
    """
    if isinstance(squash_name, bool):
        return
    raise TypeError('squash_name should be a bool')
|
|
179
|
+
|
|
180
|
+
|
|
187
181
|
def validate_alert(alert):
|
|
188
182
|
if not isinstance(alert, dict):
|
|
189
183
|
raise TypeError('alert should be a dictionary')
|
|
@@ -240,6 +234,9 @@ def validate_config(config):
|
|
|
240
234
|
mg_distribution = config.get('mg_distribution', False)
|
|
241
235
|
validate_mg_distribution(mg_distribution)
|
|
242
236
|
|
|
237
|
+
param_distribution = config.get('param_distribution', False)
|
|
238
|
+
validate_param_distribution(param_distribution)
|
|
239
|
+
|
|
243
240
|
cc_distribution = config.get('cc_distribution', {})
|
|
244
241
|
validate_cc_distribution(cc_distribution)
|
|
245
242
|
|
|
@@ -248,3 +245,42 @@ def validate_config(config):
|
|
|
248
245
|
|
|
249
246
|
step_count_per_record = config.get('step_count_per_record', 1)
|
|
250
247
|
validate_step_count_per_record(step_count_per_record)
|
|
248
|
+
|
|
249
|
+
squash_name = config.get('squash_name', True)
|
|
250
|
+
validate_squash_name(squash_name)
|
|
251
|
+
|
|
252
|
+
if not targets:
|
|
253
|
+
if xy_distribution:
|
|
254
|
+
config["all_xy"] = True
|
|
255
|
+
config["targets"] = {"": {}}
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def time_str2time_digit(time_str):
    """Parse a time tag such as ``'Dec03_21-34-40'`` into a datetime.

    Args:
        time_str: Timestamp text in ``%b%d_%H-%M-%S`` form.

    Returns:
        datetime: The parsed value. The format carries no year, so the year
        is whatever ``strptime`` fills in by default.

    Raises:
        RuntimeError: If ``time_str`` cannot be parsed with that format; the
            underlying error is attached as the cause.
    """
    # NOTE(review): with no year in the format, a 'Feb29_...' tag is likely
    # rejected because the default year is not a leap year; confirm whether
    # such tags can occur.
    time_format = '%b%d_%H-%M-%S'
    try:
        time_digit = datetime.strptime(time_str, time_format)
    except (TypeError, ValueError) as e:
        # Adjacent literals rather than a backslash continuation inside the
        # string, so source indentation cannot leak into the message.
        raise RuntimeError(
            f"illegal timestamp: {time_str}, timestamp should be prefix "
            "of existing output dirpath, like 'Dec03_21-34-40'."
        ) from e
    return time_digit
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def get_target_output_dir(monitor_path, time_start, time_end):
    """Collect output directories under ``monitor_path`` within a time window.

    Args:
        monitor_path: Directory to scan; validated with
            ``check_file_or_directory_path(..., isdir=True)``.
        time_start: Inclusive lower bound as a time-tag string, or None for
            no lower bound.
        time_end: Inclusive upper bound as a time-tag string, or None for no
            upper bound.

    Returns:
        dict: Maps the rank captured by ``MonitorConst.OUTPUT_DIR_PATTERN``
        (group 2) to the joined directory path. A later entry for the same
        rank replaces an earlier one.

    Raises:
        ValueError: If both bounds are given and ``time_start`` is later
            than ``time_end``.
    """
    check_file_or_directory_path(monitor_path, isdir=True)
    if time_start is not None:
        time_start = time_str2time_digit(time_start)
    if time_end is not None:
        time_end = time_str2time_digit(time_end)
    if time_start and time_end and time_start > time_end:
        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")

    result = {}
    for dirname in os.listdir(monitor_path):
        matched = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
        if matched is None:
            continue
        time_tag, rank = matched.group(1), matched.group(2)
        target_time = time_str2time_digit(time_tag)
        # Skip entries that fall outside either requested bound.
        if time_start is not None and target_time < time_start:
            continue
        if time_end is not None and target_time > time_end:
            continue
        result[rank] = os.path.join(monitor_path, dirname)
    return result
|