mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import atexit
|
|
17
17
|
import os
|
|
18
|
+
import traceback
|
|
18
19
|
|
|
19
20
|
from msprobe.core.data_dump.scope import ScopeFactory
|
|
20
21
|
from msprobe.core.data_dump.json_writer import DataWriter
|
|
@@ -41,7 +42,7 @@ class DataCollector:
|
|
|
41
42
|
self.backward_module_names = {}
|
|
42
43
|
self.optimizer_status = ""
|
|
43
44
|
self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
|
|
44
|
-
atexit.register(self.
|
|
45
|
+
atexit.register(self.write_json_at_exit)
|
|
45
46
|
|
|
46
47
|
@property
|
|
47
48
|
def dump_data_dir(self):
|
|
@@ -78,6 +79,11 @@ class DataCollector:
|
|
|
78
79
|
def write_json(self):
|
|
79
80
|
self.data_writer.write_json()
|
|
80
81
|
|
|
82
|
+
def write_json_at_exit(self):
|
|
83
|
+
if self.config.async_dump and self.config.task == Const.TENSOR:
|
|
84
|
+
self.data_processor.dump_async_data()
|
|
85
|
+
self.data_writer.write_json()
|
|
86
|
+
|
|
81
87
|
def update_data(self, name, data_info):
|
|
82
88
|
msg = f"msprobe is collecting data on {name}."
|
|
83
89
|
if self.config.task == Const.OVERFLOW_CHECK:
|
|
@@ -89,88 +95,155 @@ class DataCollector:
|
|
|
89
95
|
logger.debug(msg)
|
|
90
96
|
self.data_writer.update_data(data_info)
|
|
91
97
|
|
|
92
|
-
def
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
96
|
-
self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
|
|
97
|
-
return
|
|
98
|
-
|
|
99
|
-
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
100
|
-
return
|
|
98
|
+
def call_stack_collect(self, name):
|
|
99
|
+
stack_info = self.data_processor.analyze_api_call_stack(name)
|
|
100
|
+
self.data_writer.update_stack(name, stack_info)
|
|
101
101
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
102
|
+
def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
103
|
+
try:
|
|
104
|
+
|
|
105
|
+
if self.config.task == Const.FREE_BENCHMARK:
|
|
106
|
+
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
107
|
+
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
108
|
+
self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
|
|
109
|
+
return
|
|
110
|
+
|
|
111
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
112
|
+
return
|
|
113
|
+
|
|
114
|
+
data_info = {}
|
|
115
|
+
if self.config.task != Const.STRUCTURE:
|
|
116
|
+
data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
|
|
117
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
118
|
+
if self.config.level == Const.LEVEL_L2:
|
|
119
|
+
return
|
|
120
|
+
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
121
|
+
|
|
122
|
+
except Exception:
|
|
123
|
+
tb = traceback.format_exc()
|
|
124
|
+
self.data_writer.write_error_log(
|
|
125
|
+
f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
126
|
+
)
|
|
109
127
|
|
|
110
128
|
def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
data_info =
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
129
|
+
try:
|
|
130
|
+
|
|
131
|
+
self.update_construct(name)
|
|
132
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
133
|
+
return
|
|
134
|
+
|
|
135
|
+
data_info = {}
|
|
136
|
+
if self.config.task != Const.STRUCTURE:
|
|
137
|
+
data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
|
|
138
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
139
|
+
if self.config.level == Const.LEVEL_L2:
|
|
140
|
+
return
|
|
141
|
+
self.call_stack_collect(name)
|
|
142
|
+
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
143
|
+
|
|
144
|
+
except Exception:
|
|
145
|
+
tb = traceback.format_exc()
|
|
146
|
+
self.data_writer.write_error_log(
|
|
147
|
+
f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
|
|
151
|
+
try:
|
|
152
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
153
|
+
return
|
|
154
|
+
self.data_processor.analyze_forward(name, module, module_input_output)
|
|
155
|
+
|
|
156
|
+
except Exception:
|
|
157
|
+
tb = traceback.format_exc()
|
|
158
|
+
self.data_writer.write_error_log(
|
|
159
|
+
f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
|
|
160
|
+
)
|
|
123
161
|
|
|
124
162
|
def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
163
|
+
try:
|
|
164
|
+
|
|
165
|
+
self.update_construct(name)
|
|
166
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
167
|
+
return
|
|
168
|
+
data_info = {}
|
|
169
|
+
if self.config.task != Const.STRUCTURE:
|
|
170
|
+
data_info = self.data_processor.analyze_forward(name, module, module_input_output)
|
|
171
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
172
|
+
self.call_stack_collect(name)
|
|
173
|
+
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
174
|
+
|
|
175
|
+
except Exception:
|
|
176
|
+
tb = traceback.format_exc()
|
|
177
|
+
self.data_writer.write_error_log(
|
|
178
|
+
f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
|
|
182
|
+
try:
|
|
183
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
184
|
+
return
|
|
185
|
+
self.data_processor.analyze_backward(name, module, module_input_output)
|
|
186
|
+
|
|
187
|
+
except Exception:
|
|
188
|
+
tb = traceback.format_exc()
|
|
189
|
+
self.data_writer.write_error_log(
|
|
190
|
+
f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
|
|
191
|
+
)
|
|
135
192
|
|
|
136
193
|
def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
194
|
+
try:
|
|
195
|
+
self.update_construct(name)
|
|
196
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
197
|
+
return
|
|
198
|
+
data_info = {}
|
|
199
|
+
if self.config.task != Const.STRUCTURE:
|
|
200
|
+
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
201
|
+
if self.config.level == Const.LEVEL_L2:
|
|
202
|
+
return
|
|
203
|
+
if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
|
|
204
|
+
module_name = name.rsplit(Const.SEP, 2)[0]
|
|
205
|
+
self.backward_module_names[module_name] = True
|
|
206
|
+
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
207
|
+
|
|
208
|
+
except Exception:
|
|
209
|
+
tb = traceback.format_exc()
|
|
210
|
+
self.data_writer.write_error_log(
|
|
211
|
+
f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
212
|
+
)
|
|
152
213
|
|
|
153
214
|
def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
215
|
+
try:
|
|
216
|
+
self.update_construct(name)
|
|
217
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
218
|
+
return
|
|
219
|
+
data_info = {}
|
|
220
|
+
if self.config.task != Const.STRUCTURE:
|
|
221
|
+
data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
|
|
222
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
223
|
+
self.handle_data(name, data_info)
|
|
224
|
+
|
|
225
|
+
except Exception:
|
|
226
|
+
tb = traceback.format_exc()
|
|
227
|
+
self.data_writer.write_error_log(
|
|
228
|
+
f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
229
|
+
)
|
|
163
230
|
|
|
164
231
|
def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
232
|
+
try:
|
|
233
|
+
self.update_construct(name)
|
|
234
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
235
|
+
return
|
|
236
|
+
data_info = {}
|
|
237
|
+
if self.config.task != Const.STRUCTURE:
|
|
238
|
+
data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
|
|
239
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
240
|
+
self.handle_data(name, data_info)
|
|
241
|
+
|
|
242
|
+
except Exception:
|
|
243
|
+
tb = traceback.format_exc()
|
|
244
|
+
self.data_writer.write_error_log(
|
|
245
|
+
f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
|
|
246
|
+
)
|
|
174
247
|
|
|
175
248
|
def update_construct(self, name):
|
|
176
249
|
if self.config.level not in DataCollector.level_without_construct:
|
|
@@ -180,7 +253,10 @@ class DataCollector:
|
|
|
180
253
|
self.optimizer_status_first_start[self.optimizer_status] = False
|
|
181
254
|
self.data_writer.update_construct({name: self.optimizer_status})
|
|
182
255
|
else:
|
|
183
|
-
self.
|
|
256
|
+
if self.config.level == Const.LEVEL_MIX and \
|
|
257
|
+
not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
|
|
258
|
+
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
259
|
+
|
|
184
260
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
185
261
|
|
|
186
262
|
def handle_data(self, name, data_info, flush=False):
|
|
@@ -203,28 +279,33 @@ class DataCollector:
|
|
|
203
279
|
self.data_processor.update_iter(current_iter)
|
|
204
280
|
|
|
205
281
|
def params_data_collect(self, name, param_name, pid, data):
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
282
|
+
try:
|
|
283
|
+
grad_name = name + Const.SEP + Const.PARAMS_GRAD
|
|
284
|
+
self.update_api_or_module_name(grad_name)
|
|
285
|
+
if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
|
|
286
|
+
if self.data_writer.cache_data.get("data"):
|
|
287
|
+
self.data_writer.cache_data.get("data").pop(grad_name, None)
|
|
288
|
+
return
|
|
289
|
+
data_info = self.data_processor.analyze_params(grad_name, param_name, data)
|
|
290
|
+
self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
|
|
291
|
+
except Exception:
|
|
292
|
+
tb = traceback.format_exc()
|
|
293
|
+
self.data_writer.write_error_log(
|
|
294
|
+
f"[ERROR] params_data_collect failed: "
|
|
295
|
+
f"name={name}, param_name={param_name}, pid={pid}\n{tb}"
|
|
296
|
+
)
|
|
218
297
|
|
|
219
298
|
def debug_data_collect_forward(self, variable, name_with_count):
|
|
220
|
-
|
|
221
299
|
data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
|
|
222
|
-
|
|
300
|
+
name_with_count_category = name_with_count + Const.SEP + Const.DEBUG
|
|
301
|
+
self.data_writer.update_debug({name_with_count_category: data_info})
|
|
223
302
|
|
|
224
303
|
def debug_data_collect_backward(self, variable, grad_name_with_count):
|
|
225
304
|
# prepare all None nested data structure
|
|
226
305
|
all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
|
|
227
|
-
|
|
306
|
+
grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG
|
|
307
|
+
self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info})
|
|
228
308
|
|
|
229
309
|
# register tensor backward hook
|
|
230
|
-
self.data_processor.analyze_debug_backward(variable,
|
|
310
|
+
self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category,
|
|
311
|
+
self.data_writer.cache_debug['data'])
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,17 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import copy
|
|
16
17
|
import inspect
|
|
17
18
|
import os
|
|
18
19
|
from dataclasses import dataclass, is_dataclass
|
|
19
|
-
from typing import Tuple, Dict, Optional, Any
|
|
20
20
|
from functools import partial
|
|
21
|
-
import
|
|
22
|
-
from typing import Union
|
|
21
|
+
from typing import Tuple, Dict, Optional, Any, Union
|
|
23
22
|
|
|
24
23
|
import numpy as np
|
|
25
24
|
|
|
26
25
|
from msprobe.core.common.const import Const
|
|
26
|
+
from msprobe.core.common.file_utils import save_npy
|
|
27
27
|
from msprobe.core.common.log import logger
|
|
28
28
|
from msprobe.core.common.utils import convert_tuple, CompareException
|
|
29
29
|
|
|
@@ -79,21 +79,17 @@ class ModuleBackwardOutputs:
|
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
class TensorStatInfo:
|
|
82
|
-
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None
|
|
82
|
+
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
|
|
83
83
|
self.max = max_val
|
|
84
84
|
self.min = min_val
|
|
85
85
|
self.mean = mean_val
|
|
86
86
|
self.norm = norm_val
|
|
87
|
-
self.stack_tensor_stat = stack_tensor_stat
|
|
88
87
|
|
|
89
88
|
|
|
90
89
|
class BaseDataProcessor:
|
|
91
90
|
_recursive_key_stack = []
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
bool, int, float, str, slice,
|
|
95
|
-
type(Ellipsis)
|
|
96
|
-
)
|
|
91
|
+
builtin_type = (bool, int, float, str, slice, type(Ellipsis))
|
|
92
|
+
np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray)
|
|
97
93
|
|
|
98
94
|
def __init__(self, config, data_writer):
|
|
99
95
|
self.data_writer = data_writer
|
|
@@ -120,7 +116,10 @@ class BaseDataProcessor:
|
|
|
120
116
|
@staticmethod
|
|
121
117
|
def analyze_api_call_stack(name):
|
|
122
118
|
try:
|
|
123
|
-
|
|
119
|
+
if name.startswith("Primitive"):
|
|
120
|
+
api_stack = inspect.stack()[4:]
|
|
121
|
+
else:
|
|
122
|
+
api_stack = inspect.stack()[5:]
|
|
124
123
|
except Exception as e:
|
|
125
124
|
logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
|
|
126
125
|
api_stack = None
|
|
@@ -129,12 +128,14 @@ class BaseDataProcessor:
|
|
|
129
128
|
for (_, path, line, func, code, _) in api_stack:
|
|
130
129
|
if not code:
|
|
131
130
|
continue
|
|
131
|
+
if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \
|
|
132
|
+
Const.CALL_STACK_FLAG not in path:
|
|
133
|
+
continue
|
|
132
134
|
stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
|
|
133
135
|
stack_str.append(stack_line)
|
|
134
136
|
else:
|
|
135
137
|
stack_str.append(Const.WITHOUT_CALL_STACK)
|
|
136
|
-
|
|
137
|
-
return stack_info_struct
|
|
138
|
+
return tuple(stack_str)
|
|
138
139
|
|
|
139
140
|
@staticmethod
|
|
140
141
|
def transfer_type(data):
|
|
@@ -178,20 +179,8 @@ class BaseDataProcessor:
|
|
|
178
179
|
"invalid data_structure type or invalid index")
|
|
179
180
|
|
|
180
181
|
@staticmethod
|
|
181
|
-
def
|
|
182
|
-
|
|
183
|
-
np.integer: int,
|
|
184
|
-
np.floating: float,
|
|
185
|
-
np.bool_: bool,
|
|
186
|
-
np.complexfloating: complex,
|
|
187
|
-
np.str_: str,
|
|
188
|
-
np.byte: bytes,
|
|
189
|
-
np.unicode_: str
|
|
190
|
-
}
|
|
191
|
-
for numpy_type, builtin_type in type_mapping.items():
|
|
192
|
-
if isinstance(arg, numpy_type):
|
|
193
|
-
return builtin_type(arg), type(arg).__name__
|
|
194
|
-
return arg, ''
|
|
182
|
+
def is_distributed_op(module):
|
|
183
|
+
return getattr(module, "op_is_distributed", False)
|
|
195
184
|
|
|
196
185
|
@staticmethod
|
|
197
186
|
def _analyze_builtin(arg):
|
|
@@ -217,21 +206,40 @@ class BaseDataProcessor:
|
|
|
217
206
|
return single_arg
|
|
218
207
|
|
|
219
208
|
@staticmethod
|
|
220
|
-
def _analyze_numpy(
|
|
209
|
+
def _analyze_numpy(arg):
|
|
210
|
+
return {"type": type(arg).__name__, "value": arg.item()}
|
|
211
|
+
|
|
212
|
+
@staticmethod
|
|
213
|
+
def _analyze_ndarray(ndarray, _):
|
|
221
214
|
ndarray_json = {}
|
|
222
215
|
ndarray_json.update({'type': 'numpy.ndarray'})
|
|
223
216
|
ndarray_json.update({'dtype': str(ndarray.dtype)})
|
|
224
217
|
ndarray_json.update({'shape': ndarray.shape})
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
218
|
+
|
|
219
|
+
# 先初始化默认值
|
|
220
|
+
stats = {
|
|
221
|
+
"Max": None,
|
|
222
|
+
"Min": None,
|
|
223
|
+
"Mean": None,
|
|
224
|
+
"Norm": None
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
# 只有非空时才尝试计算
|
|
229
|
+
if ndarray.size > 0:
|
|
230
|
+
stats = {
|
|
231
|
+
"Max": np.max(ndarray).item(),
|
|
232
|
+
"Min": np.min(ndarray).item(),
|
|
233
|
+
"Mean": np.mean(ndarray).item(),
|
|
234
|
+
"Norm": np.linalg.norm(ndarray).item()
|
|
235
|
+
}
|
|
236
|
+
except Exception as e:
|
|
237
|
+
# 决定打印内容或切片
|
|
238
|
+
logger.warning(f"Error analyzing ndarray stats: {e}")
|
|
239
|
+
|
|
240
|
+
# 最后一次性更新
|
|
241
|
+
ndarray_json.update(stats)
|
|
242
|
+
|
|
235
243
|
return ndarray_json
|
|
236
244
|
|
|
237
245
|
@staticmethod
|
|
@@ -248,7 +256,7 @@ class BaseDataProcessor:
|
|
|
248
256
|
|
|
249
257
|
@classmethod
|
|
250
258
|
def get_special_types(cls):
|
|
251
|
-
return cls.
|
|
259
|
+
return cls.builtin_type + cls.np_type
|
|
252
260
|
|
|
253
261
|
@classmethod
|
|
254
262
|
def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
|
|
@@ -303,6 +311,7 @@ class BaseDataProcessor:
|
|
|
303
311
|
|
|
304
312
|
def real_hook_fn(grad):
|
|
305
313
|
return wrap_hook_fn(grad)
|
|
314
|
+
|
|
306
315
|
element.register_hook(real_hook_fn)
|
|
307
316
|
|
|
308
317
|
def if_return_forward_new_output(self):
|
|
@@ -350,6 +359,8 @@ class BaseDataProcessor:
|
|
|
350
359
|
return api_info_struct
|
|
351
360
|
|
|
352
361
|
def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
362
|
+
if self.is_distributed_op(module):
|
|
363
|
+
module_input_output.update_output_with_args_and_kwargs()
|
|
353
364
|
api_info_struct = {}
|
|
354
365
|
# check whether data_mode contains forward or input
|
|
355
366
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
|
|
@@ -427,6 +438,7 @@ class BaseDataProcessor:
|
|
|
427
438
|
api_info_struct = {}
|
|
428
439
|
self.save_name = name + Const.SEP + param_name
|
|
429
440
|
data_info = self.analyze_element(grad)
|
|
441
|
+
self.save_name = None
|
|
430
442
|
grad_info_dict = {param_name: [data_info]}
|
|
431
443
|
api_info_struct[name] = grad_info_dict
|
|
432
444
|
return api_info_struct
|
|
@@ -435,10 +447,10 @@ class BaseDataProcessor:
|
|
|
435
447
|
file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
|
|
436
448
|
if self.save_name is not None:
|
|
437
449
|
dump_data_name = (self.save_name + file_format)
|
|
438
|
-
self.save_name = None
|
|
439
450
|
else:
|
|
440
|
-
|
|
441
|
-
|
|
451
|
+
suffix_with_seq = (Const.SEP + suffix) if suffix else ""
|
|
452
|
+
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq +
|
|
453
|
+
file_format)
|
|
442
454
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
443
455
|
return dump_data_name, file_path
|
|
444
456
|
|
|
@@ -447,23 +459,32 @@ class BaseDataProcessor:
|
|
|
447
459
|
|
|
448
460
|
def analyze_debug_forward(self, variable, name_with_count):
|
|
449
461
|
self.current_api_or_module_name = name_with_count
|
|
450
|
-
self.api_data_category = Const.
|
|
451
|
-
# these two attributes are used to construct tensor file name {name_with_count}.
|
|
462
|
+
self.api_data_category = Const.DEBUG
|
|
463
|
+
# these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt
|
|
452
464
|
data_info = self.analyze_element(variable)
|
|
453
465
|
return data_info
|
|
454
466
|
|
|
455
|
-
def analyze_debug_backward(self, variable,
|
|
467
|
+
def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure):
|
|
456
468
|
def hook_fn(grad, indexes):
|
|
457
469
|
suffix = Const.SEP.join([str(index) for index in indexes])
|
|
458
|
-
|
|
470
|
+
suffix_with_sep = (Const.SEP + suffix) if suffix else ""
|
|
471
|
+
self.save_name = grad_name_with_count_category + suffix_with_sep
|
|
459
472
|
grad_data_info = self.analyze_element(grad)
|
|
460
473
|
self.save_name = None
|
|
461
|
-
full_index = [
|
|
474
|
+
full_index = [grad_name_with_count_category] + indexes
|
|
462
475
|
try:
|
|
463
476
|
self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
|
|
464
477
|
except (ValueError, IndexError) as e:
|
|
465
|
-
logger.warning(f"error
|
|
466
|
-
f"skip current recording, detailed
|
|
478
|
+
logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable,"
|
|
479
|
+
f"skip current recording, detailed information: {e}")
|
|
467
480
|
return grad
|
|
481
|
+
|
|
468
482
|
wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
|
|
469
|
-
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
483
|
+
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
484
|
+
|
|
485
|
+
def _analyze_and_save_ndarray(self, ndarray, suffix):
|
|
486
|
+
dump_data_name, file_path = self.get_save_file_path(suffix)
|
|
487
|
+
save_npy(ndarray, file_path)
|
|
488
|
+
ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix)
|
|
489
|
+
ndarray_json.update({"data_name": dump_data_name})
|
|
490
|
+
return ndarray_json
|