mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/monitor/module_metric.py
CHANGED
@@ -16,8 +16,9 @@ import re
 
 import torch
 
+from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
-from msprobe.pytorch.monitor.utils import
+from msprobe.pytorch.monitor.utils import get_nan_tensor
 
 
 def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -147,13 +148,13 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
     :param tag2tensor: {
-        '0:
-        '0:
+        '0:fc.input:0/actv': torch.randn([3, 4]),
+        '0:fc.output:0/actv': torch.randn([3, 3])
     }
     :param eps: float 1e-8
     :param out_dict:{
-        '0:
-        '0:
+        '0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+        '0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
     }
     :return: out_dict
     """
@@ -164,8 +165,10 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
         out_dict[tag] = {}
         if not torch.is_tensor(tensor):
             # Non-tensor in/output filled with nan.
-            out_dict[tag].update({metric_name:
+            out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
+        if is_float8_tensor(tensor):
+            tensor = tensor.float()
         for metric_name in ops:
            fun_metric = config_metric_registry.get(metric_name)
            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
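For context, a minimal, self-contained sketch of the behaviour these two `get_metrics()` changes introduce: non-tensor in/outputs now map to a shared NaN tensor and float8 tensors are upcast before metrics are computed. The op names and the float8 dtype check below are illustrative stand-ins for msprobe's `config_metric_registry` and `is_float8_tensor`, not the real implementations.

```python
import torch


def summarize(tag2tensor, ops=("min", "max"), eps=1e-8):
    """Simplified re-statement of the get_metrics() flow changed above (not msprobe code)."""
    nan = torch.tensor(torch.nan)  # stands in for get_nan_tensor()
    metric_fns = {"min": lambda t, e: t.min(), "max": lambda t, e: t.max()}
    float8_dtypes = tuple(d for d in (getattr(torch, "float8_e4m3fn", None),
                                      getattr(torch, "float8_e5m2", None)) if d is not None)
    out = {}
    for tag, tensor in tag2tensor.items():
        if not torch.is_tensor(tensor):
            out[tag] = {op: nan for op in ops}  # non-tensor in/outputs recorded as NaN
            continue
        if float8_dtypes and tensor.dtype in float8_dtypes:
            tensor = tensor.float()             # float8 tensors are upcast before metrics
        out[tag] = {op: metric_fns[op](tensor, eps) for op in ops}
    return out


print(summarize({"0:fc.input:0/actv": torch.randn(3, 4), "0:fc.output:0/actv": None}))
```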
msprobe/pytorch/monitor/optimizer_collect.py
CHANGED
@@ -23,16 +23,10 @@ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
 
 
 class OptimizerMon(object):
-    wrapped_optimizer = None
-
     def __init__(self) -> None:
         self.fp16_to_fp32_param = {}
         self.is_stage3 = False
 
-    @classmethod
-    def set_wrapped_optimizer(cls, wrapped_optimizer):
-        cls.wrapped_optimizer = wrapped_optimizer
-
     def fetch_mv(self, monitor, torch_opt, params2name):
         pass
 
@@ -82,7 +76,6 @@ class OptimizerMon(object):
         ratio_dict = defaultdict()
         param2name = defaultdict()
         fp32_partitioned_groups_flat_grad = defaultdict()
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
         partition_id = dist.get_rank()
 
         def get_flatten_grad(self, optimizer, group_idx):
@@ -101,7 +94,7 @@ class OptimizerMon(object):
                 return fp32_partitioned_groups_flat[group_idx].grad
 
         for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self,
+            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
 
         for name in params2name.values():
             start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
@@ -110,9 +103,9 @@
             fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
             fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
             param2name[fp32_param] = name
-            if not
+            if not torch_opt.state:
                 continue
-            state_param = list(
+            state_param = list(torch_opt.state.values())[group_idx]
             exp_avg = state_param.get("exp_avg", None)
             exp_avg_sq = state_param.get("exp_avg_sq", None)
             if exp_avg is None or exp_avg_sq is None:
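The removal of the class-level `wrapped_optimizer` / `set_wrapped_optimizer()` means every monitor now works purely off the `torch_opt` passed into `fetch_mv()`. For reference, this is the standard `torch.optim` state layout those code paths read; a toy example, not msprobe code:

```python
import torch

# Two hypothetical parameters; Adam keeps 'exp_avg'/'exp_avg_sq' per parameter.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
torch_opt = torch.optim.Adam(params, lr=1e-3)
loss = (params[0] ** 2).sum() + (params[1] ** 2).sum()
loss.backward()
torch_opt.step()

# The refactored monitors read this mapping straight from the torch_opt argument.
for param, state in torch_opt.state.items():
    print(state["exp_avg"].shape, state["exp_avg_sq"].shape)
```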
@@ -150,36 +143,33 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
     """
 
-    def map_fp16_tp_fp32_param(self,
-        for fp16_group, fp32_group in zip(
+    def map_fp16_tp_fp32_param(self, torch_opt):
+        for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param
 
     def fetch_mv(self, monitor, torch_opt, params2name):
-
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            self.map_fp16_tp_fp32_param(mix_prec_opt)
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            self.map_fp16_tp_fp32_param(torch_opt)
 
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def map_fp16_tp_fp32_param(self,
-        if not (hasattr(
-                hasattr(
+    def map_fp16_tp_fp32_param(self, torch_opt):
+        if not (hasattr(torch_opt, "model_float16_groups") and
+                hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                 "if not, please check megatron-lm version")
-        for fp16_group, shard_fp32_group in zip(
-
+        for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups,
+                                                torch_opt.shard_fp32_from_float16_groups):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
     def fetch_mv(self, monitor, torch_opt, params2name):
-
-
-        self.map_fp16_tp_fp32_param(mix_prec_opt)
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            self.map_fp16_tp_fp32_param(torch_opt)
 
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
@@ -191,43 +181,89 @@ class MegatronFP32OptimizerMon(OptimizerMon):
 
 class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            for opt in mix_prec_opt.chained_optimizers:
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            for opt in torch_opt.chained_optimizers:
                 self.map_fp16_tp_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
+        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
             torch_opt.state = {}
-            for opt in
+            for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            for opt in mix_prec_opt.chained_optimizers:
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            for opt in torch_opt.chained_optimizers:
                 self.map_fp16_tp_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
+        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
             torch_opt.state = {}
-            for opt in
+            for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
-    def
-
+    def get_group_index(self, torch_opt):
+        bit16_groups = torch_opt.bf16_groups
+        param2group = defaultdict()
+        for group_idx, bit16_group in enumerate(bit16_groups):
+            for param in bit16_group:
+                param2group[param] = group_idx
+        return param2group
+
+    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+        param2group = self.get_group_index(torch_opt)
+        exp_avg_dict = defaultdict(float)
+        exp_avg_sq_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
+
+        param_slice_mappings = torch_opt.state_dict()['param_slice_mappings']
+        for param, name in params2name.items():
+            group_idx = param2group[param]
+            state = torch_opt.optimizer.state[torch_opt.fp32_groups_flat_partition[group_idx]]
+            if state.get('exp_avg', None) is None:
+                logger.warning(f"optimizer state is None. Something is wrong if this is not the first step")
+                break
+            param_slice_mapping = param_slice_mappings[group_idx]
+            hp_address = param_slice_mapping.get(torch_opt.param_names[param])
+            if hp_address is None:
+                continue
+            start = hp_address.start
+            numel = hp_address.numel
 
+            if monitor.mv_distribution:
+                exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel)
+                exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel)
+            if monitor.mg_direction:
+                exp_avg_dict[name] = state['exp'].narrow(0, start, numel)
+            if monitor.ur_distribution:
+                if len(torch_opt.param_groups) > 1:
+                    logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                if 'step' in state:
+                    step = state['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
+                elif 'step' in torch_opt.param_groups[0]:
+                    step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                else:
+                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                    continue
+                exp_avg = state['exp_avg'].narrow(0, start, numel)
+                exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
+                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
 
 class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-    def get_param_index(self, params2name, name2index):
-
-        fp16_groups = mix_prec_opt.fp16_partitioned_groups
+    def get_param_index(self, params2name, name2index, torch_opt):
+        fp16_groups = torch_opt.fp16_partitioned_groups
         name2indices = defaultdict()
         index_length = defaultdict()
         index = 0
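For reference, the `exp_avg_hat` / `exp_avg_sq_hat` terms computed in the new `DeepSpeedZeroOptimizerStage0Mon.fetch_mv` above are the standard Adam bias-corrected moment estimates, with `m_t` = `exp_avg`, `v_t` = `exp_avg_sq`, `t` = `step`, and the betas and epsilon taken from `torch_opt.defaults`:

```latex
\hat{m}_t = \frac{m_t}{1-\beta_1^{\,t}},\qquad
\hat{v}_t = \frac{v_t}{1-\beta_2^{\,t}},\qquad
\mathrm{update} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\varepsilon},\qquad
\mathrm{ratio} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}
```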
@@ -246,13 +282,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
         self.is_stage3 = True
-
-        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
 
 class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-
     @staticmethod
     def get_group_index(fp32_length, world_size, index):
         for i in range(len(fp32_length) - 1):
@@ -265,12 +299,11 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
             return sub_interval_start, min(sub_index, world_size - 1)
         return fp32_length[-1], 0
 
-    def get_param_index(self, params2name, name2index):
-
-        padding = mix_prec_opt.groups_padding
+    def get_param_index(self, params2name, name2index, torch_opt):
+        padding = torch_opt.groups_padding
         world_size = dist.get_world_size()
         fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(
+        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
             fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
 
         bf16_groups = []
@@ -278,7 +311,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
         index_length = defaultdict()
         index = 0
         idx = 0
-        for group_idx, bf16_group in enumerate(
+        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
             bf16_groups.extend(bf16_group)
             for param in bf16_group:
                 param_length = len(param.flatten())
@@ -286,7 +319,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
                 index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
                 index += param_length
                 idx += 1
-        group_length = len(bf16_groups) / len(
+        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
         for _, name in params2name.items():
             name_index = name2index[name]
             start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
@@ -300,8 +333,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
         return name2indices
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-
-        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
 
@@ -312,22 +344,23 @@ class DummyOptimizerMon(OptimizerMon):
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "
-        "
-        "
-        "
-        "
-        "
-        "
+        "FP32Optimizer": MegatronFP32OptimizerMon,
+        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+        "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+        "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "
+        "Adam": DummyOptimizerMon
     }
 
     @staticmethod
-    def create_optimizer_mon(
-
-
-
-
-
-
+    def create_optimizer_mon(optimizer):
+        # auto replace opt_ty
+        optimizer_class = optimizer.__class__.__name__
+        if optimizer_class == "ChainedOptimizer":
+            optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
+        return optimizer_mon_class(), optimizer_class
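A minimal sketch of the new factory contract shown above: the monitor class is resolved from the optimizer instance's class name, with `ChainedOptimizer` unwrapped to its first chained optimizer and anything unknown falling back to `DummyOptimizerMon`. The import path is taken from the file list (assumes mindstudio-probe 1.3.0 is installed):

```python
import torch
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory

opt = torch.optim.Adam([torch.nn.Parameter(torch.randn(2))], lr=1e-3)
mon, opt_class = OptimizerMonFactory.create_optimizer_mon(opt)
print(opt_class)  # "Adam" -> DummyOptimizerMon per the mapping above
```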
msprobe/pytorch/monitor/unittest/test_monitor.py
CHANGED
@@ -92,7 +92,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
     if errors:
         logger.info(errors)
     else:
-        logger.info(f'grad mean is in consist between unreduced grad and reduced grad
+        logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitored.')
 
 
 def assert_equal(a, b):
msprobe/pytorch/monitor/utils.py
CHANGED
@@ -25,7 +25,7 @@ import torch
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.utils import is_int
-from msprobe.core.common.file_utils import check_file_or_directory_path
+from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
 
 
 device = "cpu"
@@ -36,7 +36,7 @@ except ImportError:
     if torch.cuda.is_available():
         device = "cuda"
 
-NAN_TENSOR_ON_DEVICE =
+NAN_TENSOR_ON_DEVICE = None
 FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
 FILE_NAME_MAX_LENGTH = 255
 DIRECTORY_MAX_LENGTH = 4096
@@ -57,6 +57,13 @@ def get_output_base_dir():
     return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
 
 
+def get_nan_tensor():
+    global NAN_TENSOR_ON_DEVICE
+    if not NAN_TENSOR_ON_DEVICE:
+        NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+    return NAN_TENSOR_ON_DEVICE
+
+
 def filter_special_chars(func):
     @wraps(func)
     def func_level(msg):
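A small design note (my reading, not from the package docs): `NAN_TENSOR_ON_DEVICE` is allocated once on the detected device and then reused, so filling metrics for non-tensor in/outputs does not allocate a fresh tensor per call. The same lazy-initialisation pattern in a standalone form (an `is None` check is used here for clarity):

```python
import torch

_NAN = None


def get_nan_tensor(device="cpu"):
    """Lazy-initialised NaN tensor, mirroring the helper added above (illustrative)."""
    global _NAN
    if _NAN is None:
        _NAN = torch.tensor(torch.nan, device=device)
    return _NAN


assert get_nan_tensor() is get_nan_tensor()  # allocated once, then reused
```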
@@ -82,48 +89,6 @@ def get_param_struct(param):
     return res
 
 
-def is_recomputation():
-    """Check if the current operation is in the re-computation phase.
-
-    This function inspects the current call stack to indicate whether the current operation is in the
-    re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
-    megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
-    mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
-    file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the
-    'torch/_tensor.py' file.
-
-    Returns:
-        bool: True if in the re-computation phase, False otherwise.
-    """
-    backward_function_indices = []
-    call_stack = inspect.stack()
-
-    # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
-    for frame_info in call_stack:
-        if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'):
-            del call_stack
-            return True
-
-    # Identify indices in the call stack where the specific function is being executed
-    for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
-            backward_function_indices.append(idx)
-
-    # Check if the execution is within 'torch/autograd/function.py' file
-    for idx in backward_function_indices:
-        # The Megatron and MindSpeed L0&L1 scenes
-        if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
-            del call_stack
-            return True
-        # The latest MindSpeed L2 and ModelLink scenes
-        if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
-            del call_stack
-            return True
-
-    del call_stack
-    return False
-
-
 def validate_ops(ops):
     if not isinstance(ops, list):
         raise TypeError("ops should be a list")
@@ -140,6 +105,15 @@ def validate_ops(ops):
     return valid_ops
 
 
+def validate_ndigits(ndigits):
+    if not ndigits:
+        return
+    if not is_int(ndigits) or ndigits <= 0:
+        raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
+    if ndigits > MonitorConst.MAX_NDIGITS:
+        raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
+
+
 def validate_ranks(ranks):
     if not isinstance(ranks, list):
         raise TypeError("module_ranks should be a list")
@@ -241,9 +215,17 @@ def validate_step_count_per_record(step_count_per_record):
         raise ValueError("step_count_per_record must smaller than 1e6")
 
 
+def validate_dynamic_on(dynamic_on):
+    if not isinstance(dynamic_on, bool):
+        raise TypeError('dynamic_on should be a bool')
+
+
 def validate_config(config):
     config['ops'] = validate_ops(config.get('ops', []))
 
+    ndigits = config.get('ndigits')
+    validate_ndigits(ndigits)
+
     eps = config.get('eps', 1e-8)
     if not isinstance(eps, float):
         raise TypeError("eps should be a float")
@@ -281,9 +263,20 @@ def validate_config(config):
     step_count_per_record = config.get('step_count_per_record', 1)
     validate_step_count_per_record(step_count_per_record)
 
+    config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
+                                            MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
+    config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
+                                               MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
+                                               MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
+    config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
+                                               MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
+
     squash_name = config.get('squash_name', True)
     validate_squash_name(squash_name)
 
+    dynamic_on = config.get('dynamic_on', False)
+    validate_dynamic_on(dynamic_on)
+
     if not targets:
         if xy_distribution:
             config["all_xy"] = True
@@ -292,6 +285,8 @@
 
 def time_str2time_digit(time_str):
     time_format = '%b%d_%H-%M-%S'
+    if not isinstance(time_str, str):
+        raise TypeError(f"time_str:{time_str} should be a str")
     try:
         time_digit = datetime.strptime(time_str, time_format)
     except Exception as e:
@@ -319,3 +314,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
         if start_ok and end_ok:
             result[rank] = os.path.join(monitor_path, dirname)
     return result
+
+
+def chmod_tensorboard_dir(path):
+    """
+    format配置为tensorboard时,需要补充文件权限设置
+    """
+    try:
+        recursive_chmod(path)
+    except Exception as e:
+        logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
+
+
+def validate_set_monitor(grad_acc_steps, start_iteration):
+    """
+    validate parameters of set_monitor.
+    """
+    grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
+                                      MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
+
+    start_iteration = validate_int_arg(start_iteration, "start_iteration",
+                                       MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
+    return grad_acc_steps, start_iteration
+
+
+def validate_int_arg(value, name, minimum, default_value):
+    """Validate int args, if any exception occurs, use the default value."""
+    if value is None:
+        return default_value
+    try:
+        if not is_int(value):
+            raise TypeError(f"{name} must be int")
+        if value < minimum:
+            raise ValueError(f"{name} must greater than {minimum}")
+    except Exception as e:
+        value = default_value
+        logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
+    return value
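The new int validators never raise at configuration time; they log and fall back to the default. A quick illustration of that contract, as a simplified standalone mirror of `validate_int_arg` above (not the packaged function itself):

```python
def validate_int_arg(value, name, minimum, default_value):
    """Mirror of the helper above: missing or invalid values fall back to the default."""
    if value is None:
        return default_value
    if not isinstance(value, int) or isinstance(value, bool) or value < minimum:
        print(f"Validate {name} failed, replaced with default value {default_value}.")
        return default_value
    return value


print(validate_int_arg(None, "start_step", 0, 0))     # -> 0 (key missing, use default)
print(validate_int_arg(-5, "start_step", 0, 0))       # -> 0 (below minimum, falls back)
print(validate_int_arg(100, "collect_times", 1, 10))  # -> 100 (valid values pass through)
```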
@@ -125,8 +125,6 @@ class Saver:
 
     def write_summary_csv(self, test_result):
         test_rows = []
-        if self.stack_info:
-            test_rows[0].append(self.COLUMN_STACK_INFO)
 
         check_op_str_pattern_valid(test_result.api_name)
         df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
msprobe/pytorch/online_dispatch/dispatch.py
CHANGED
@@ -16,6 +16,7 @@
 import json
 import os
 import time
+import multiprocessing
 from multiprocessing import Pool
 
 import torch
@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
             return
         if dump_path is None:
             logger.error("Please set dump_path when dump_mode is config!")
+            raise DispatchException("Please set dump_path when dump_mode is config!")
         check_file_or_directory_path(dump_path, True)
 
         self.device_id = torch_npu._C._npu_getDevice()
@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
         self.get_ops(yaml_path)
 
         self.lock = None
+        max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
+        if process_num > max_process_num:
+            logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
+            raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
+                                    f'but got {process_num}!')
         if process_num > 0:
             self.pool = Pool(process_num)
         if debug:
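For a feel of the new `process_num` cap: `Const.CPU_QUARTER` is not shown in this diff, so the value 4 below is assumed purely for the arithmetic.

```python
import multiprocessing

CPU_QUARTER = 4  # assumption for illustration; the real value comes from msprobe's Const

cpu_count = multiprocessing.cpu_count()              # e.g. 32 on a 32-core host
max_process_num = max(int((cpu_count + 1) // CPU_QUARTER), 1)
print(max_process_num)                               # e.g. (32 + 1) // 4 = 8
```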
@@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode):
             if len(json_line_data) == 0:
                 break
             msg = json.loads(json_line_data)
+            if len(msg) < 2:
+                raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
             self.all_summary[msg[0]] = msg[1]
         fp_handle.close()
 
msprobe/pytorch/online_dispatch/dump_compare.py
CHANGED
@@ -19,6 +19,8 @@ import os
 from datetime import datetime, timezone
 
 import torch
+from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
 from msprobe.pytorch.common.log import logger
 
@@ -91,6 +93,7 @@ def support_basic_type(data):
     return False
 
 
+@recursion_depth_decorator("dump_data")
 def dump_data(data, prefix, dump_path):
     if isinstance(data, (tuple, list)) and data:
         for i, item in enumerate(data):
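`msprobe/core/common/decorator.py` is new in 1.3.0 (see the file list above) but its body is not included in these hunks. As a rough, hypothetical illustration of what a recursion-depth guard of this shape typically does — the limit, error type and internals below are assumptions, not msprobe's code:

```python
import functools


def recursion_depth_decorator(name, max_depth=100):
    """Hypothetical sketch: refuse to recurse past max_depth for the wrapped function."""
    def wrapper(func):
        depth = 0

        @functools.wraps(func)
        def inner(*args, **kwargs):
            nonlocal depth
            if depth >= max_depth:
                raise RecursionError(f"{name} exceeded max recursion depth {max_depth}")
            depth += 1
            try:
                return func(*args, **kwargs)
            finally:
                depth -= 1
        return inner
    return wrapper


@recursion_depth_decorator("demo")
def count_down(n):
    return 0 if n == 0 else count_down(n - 1)


print(count_down(10))  # fine; a much deeper recursion would raise RecursionError
```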
msprobe/pytorch/online_dispatch/utils.py
CHANGED
@@ -27,8 +27,10 @@ else:
     pta_cpu_device = torch.device("cpu")
 
 from msprobe.core.common.const import CompareConst
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
 
+
 cpu_device = torch._C.device("cpu")
 COLOR_RED = '\033[31m'
 COLOR_GREEN = '\033[32m'
@@ -85,6 +87,7 @@ def get_callstack():
     return callstack
 
 
+@recursion_depth_decorator("data_to_cpu")
 def data_to_cpu(data, deep, data_cpu):
     global cpu_device
     list_cpu = []
msprobe/pytorch/parse_tool/lib/interactive_cli.py
CHANGED
@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):
 
     @catch_exception
     def default(self, line=""):
-        self.
-        return False
-
-    @catch_exception
-    def do_run(self, line=""):
-        self.util.execute_command(line)
+        self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")
 
     @catch_exception
     def do_vc(self, line=""):
msprobe/pytorch/parse_tool/lib/utils.py
CHANGED
@@ -119,6 +119,7 @@ class Util:
 
     @staticmethod
     def deal_with_dir_or_file_inconsistency(output_path):
+        logger.warning(f"Trying to delete {output_path}")
         remove_path(output_path)
         raise ParseException("Inconsistent directory structure or file.")
 
@@ -264,7 +265,7 @@ class Util:
             match = re_pattern.match(name)
             if not match:
                 continue
-            if extern_pattern != '' and re_pattern.match(extern_pattern) and not
+            if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
                 continue
             file_list[name] = gen_info_func(name, match, file["root"])
         return file_list