mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -50,8 +50,8 @@ else:
|
|
|
50
50
|
from msprobe.pytorch.common.utils import logger
|
|
51
51
|
from msprobe.core.common.const import Const, CompareConst
|
|
52
52
|
|
|
53
|
-
|
|
54
|
-
|
|
53
|
+
GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
|
|
54
|
+
SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
def softmax_forward(x):
|
|
@@ -166,6 +166,18 @@ def parse_bsnd_args(query, key, head_num, input_layout):
|
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
def convert_from_bnsd(_input, input_layout):
|
|
169
|
+
"""
|
|
170
|
+
transform qkv from bnsd to input_layout.
|
|
171
|
+
B: batch_size
|
|
172
|
+
S: sequence_length
|
|
173
|
+
N: num_heads
|
|
174
|
+
D: head_dim
|
|
175
|
+
Args:
|
|
176
|
+
_input (torch.Tensor): tensor of shape (B,N,S,D)
|
|
177
|
+
input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
|
|
178
|
+
Returns:
|
|
179
|
+
tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
|
|
180
|
+
"""
|
|
169
181
|
if input_layout == "BSH":
|
|
170
182
|
# (B,N,S,D)=>(B,S,N*D)
|
|
171
183
|
out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
|
|
@@ -183,7 +195,19 @@ def convert_from_bnsd(_input, input_layout):
|
|
|
183
195
|
|
|
184
196
|
|
|
185
197
|
def convert_to_bnsd(_input, n, input_layout):
|
|
186
|
-
|
|
198
|
+
"""
|
|
199
|
+
transform qkv from input_layout to bnsd.
|
|
200
|
+
B: batch_size
|
|
201
|
+
S: sequence_length
|
|
202
|
+
N: num_heads
|
|
203
|
+
D: head_dim
|
|
204
|
+
Args:
|
|
205
|
+
_input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
|
|
206
|
+
n (int): num_heads
|
|
207
|
+
input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
|
|
208
|
+
Returns:
|
|
209
|
+
tensor of shape (B,N,S,D)
|
|
210
|
+
"""
|
|
187
211
|
if input_layout == "BSH":
|
|
188
212
|
# (B,S,N*D)=>(B,N,S,D)
|
|
189
213
|
out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
|
|
@@ -199,7 +223,68 @@ def convert_to_bnsd(_input, n, input_layout):
|
|
|
199
223
|
out = _input
|
|
200
224
|
if out.dim() != 4:
|
|
201
225
|
raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
|
|
202
|
-
return out.to(
|
|
226
|
+
return out.to(GTYPE)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def convert_from_bsnd(_input, input_layout):
|
|
230
|
+
"""
|
|
231
|
+
transform qkv from bsnd to input_layout.
|
|
232
|
+
B: batch_size
|
|
233
|
+
S: sequence_length
|
|
234
|
+
N: num_heads
|
|
235
|
+
D: head_dim
|
|
236
|
+
Args:
|
|
237
|
+
_input (torch.Tensor): tensor of shape (B,S,N,D)
|
|
238
|
+
input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
|
|
239
|
+
Returns:
|
|
240
|
+
tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
|
|
241
|
+
"""
|
|
242
|
+
if input_layout == "BSH":
|
|
243
|
+
# (B,S,N,D)=>(B,S,N*D)
|
|
244
|
+
out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
|
|
245
|
+
elif input_layout == "SBH":
|
|
246
|
+
# (B,S,N,D)=>(S,B,N*D)
|
|
247
|
+
out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
|
|
248
|
+
elif input_layout == "BNSD":
|
|
249
|
+
# (B,S,N,D)=>(B,N,S,D)
|
|
250
|
+
out = rearrange(_input, 'b s n d -> b n s d').contiguous()
|
|
251
|
+
elif input_layout == "TND":
|
|
252
|
+
raise ValueError(f"input_layout {input_layout} does not supported for now.")
|
|
253
|
+
else:
|
|
254
|
+
out = _input
|
|
255
|
+
return out
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def convert_to_bsnd(_input, n, input_layout):
|
|
259
|
+
"""
|
|
260
|
+
transform qkv from input_layout to bsnd.
|
|
261
|
+
B: batch_size
|
|
262
|
+
S: sequence_length
|
|
263
|
+
N: num_heads
|
|
264
|
+
D: head_dim
|
|
265
|
+
Args:
|
|
266
|
+
_input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
|
|
267
|
+
n (int): num_heads
|
|
268
|
+
input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
|
|
269
|
+
Returns:
|
|
270
|
+
tensor of shape (B,S,N,D)
|
|
271
|
+
"""
|
|
272
|
+
if input_layout == "BSH":
|
|
273
|
+
# (B,S,N*D)=>(B,S,N,D)
|
|
274
|
+
out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
|
|
275
|
+
elif input_layout == "SBH":
|
|
276
|
+
# (S,B,N*D)=>(B,S,N,D)
|
|
277
|
+
out = rearrange(_input, 's b (n d) -> b s n d', n=n)
|
|
278
|
+
elif input_layout == "BNSD":
|
|
279
|
+
# (B,N,S,D)=>(B,S,N,D)
|
|
280
|
+
out = rearrange(_input, 'b n s d -> b s n d', n=n)
|
|
281
|
+
elif input_layout == "TND":
|
|
282
|
+
raise ValueError(f"input_layout {input_layout} does not supported for now.")
|
|
283
|
+
else:
|
|
284
|
+
out = _input
|
|
285
|
+
if out.dim() != 4:
|
|
286
|
+
raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
|
|
287
|
+
return out
|
|
203
288
|
|
|
204
289
|
|
|
205
290
|
def generate_atten_mask(*args):
|
|
@@ -279,7 +364,7 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
|
|
|
279
364
|
"""
|
|
280
365
|
logger.info("Using QKV to rebuild original softmax")
|
|
281
366
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
282
|
-
softmax_res,
|
|
367
|
+
softmax_res, _, _ = softmax_forward(qk)
|
|
283
368
|
return softmax_res
|
|
284
369
|
|
|
285
370
|
|
|
@@ -319,6 +404,10 @@ def get_input_layout(*args, **kwargs):
|
|
|
319
404
|
|
|
320
405
|
|
|
321
406
|
def npu_fusion_attention_forward_patch(*args, **kwargs):
|
|
407
|
+
|
|
408
|
+
if len(args) < 2:
|
|
409
|
+
raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
|
|
410
|
+
|
|
322
411
|
# query, key, value, head_num, input_layout
|
|
323
412
|
head_num = get_head_num(*args, **kwargs)
|
|
324
413
|
input_layout = get_input_layout(*args, **kwargs)
|
|
@@ -454,7 +543,7 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
454
543
|
value = convert_to_bnsd(value, n2, input_layout)
|
|
455
544
|
k_new, v_new = generate_kv(key, value, n1, n2)
|
|
456
545
|
|
|
457
|
-
if
|
|
546
|
+
if SOFTMAX_BUILD_MODE == "QKV":
|
|
458
547
|
softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
|
|
459
548
|
else:
|
|
460
549
|
softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
|
|
@@ -531,8 +620,13 @@ def gpu_fusion_attention(*args, **kwargs):
|
|
|
531
620
|
else:
|
|
532
621
|
alibi_slopes = None
|
|
533
622
|
|
|
623
|
+
input_layout = get_input_layout(*args, **kwargs)
|
|
624
|
+
query = convert_to_bsnd(query, n1, input_layout)
|
|
625
|
+
key = convert_to_bsnd(key, n2, input_layout)
|
|
626
|
+
value = convert_to_bsnd(value, n2, input_layout)
|
|
534
627
|
out = flash_attn_func(
|
|
535
628
|
query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
|
|
536
629
|
window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
|
|
537
630
|
)
|
|
631
|
+
out = convert_from_bsnd(out, input_layout)
|
|
538
632
|
return out, Const.NONE, Const.NONE
|
|
@@ -40,6 +40,9 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
|
|
|
40
40
|
x_shape = x.shape
|
|
41
41
|
h = x.float()
|
|
42
42
|
grad = dy_tensor.float()
|
|
43
|
+
if len(r1_shape) < 4 or len(x_shape) < 4:
|
|
44
|
+
raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
|
|
45
|
+
f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
|
|
43
46
|
condition_1 = (r1_shape[0] == 1
|
|
44
47
|
and r1_shape[1] == x_shape[1]
|
|
45
48
|
and r1_shape[2] == 1
|
|
@@ -68,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
|
|
|
68
71
|
for j in range(x_shape[2]):
|
|
69
72
|
r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
|
|
70
73
|
r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
|
|
74
|
+
|
|
71
75
|
return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
|
|
@@ -19,7 +19,11 @@ import torch
|
|
|
19
19
|
def npu_swiglu(x, dim=-1):
|
|
20
20
|
tensor_dtype = x.dtype
|
|
21
21
|
|
|
22
|
-
|
|
22
|
+
try:
|
|
23
|
+
in_tensors = torch.chunk(x, 2, dim=dim)
|
|
24
|
+
except Exception as e:
|
|
25
|
+
raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
|
|
26
|
+
|
|
23
27
|
if tensor_dtype == torch.float32:
|
|
24
28
|
tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
|
|
25
29
|
output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
|
|
@@ -34,7 +38,11 @@ def npu_swiglu(x, dim=-1):
|
|
|
34
38
|
|
|
35
39
|
def npu_swiglu_backward(grad, x, dim=-1):
|
|
36
40
|
tensor_dtype = grad.dtype
|
|
37
|
-
|
|
41
|
+
try:
|
|
42
|
+
in_tensors = torch.chunk(x, 2, dim=dim)
|
|
43
|
+
except Exception as e:
|
|
44
|
+
raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
|
|
45
|
+
|
|
38
46
|
tensor_grad_out = grad
|
|
39
47
|
|
|
40
48
|
if tensor_dtype == torch.float16:
|
|
@@ -13,20 +13,20 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import json
|
|
17
|
-
|
|
18
16
|
from msprobe.core.common.exceptions import ParseJsonException
|
|
19
|
-
from msprobe.core.common.file_utils import
|
|
17
|
+
from msprobe.core.common.file_utils import load_json
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
def parse_json_info_forward_backward(json_path):
|
|
23
|
-
|
|
24
|
-
dump_json = json.load(f)
|
|
22
|
+
dump_json = load_json(json_path)
|
|
25
23
|
|
|
26
24
|
real_data_path = dump_json.get("dump_data_dir")
|
|
27
25
|
dump_data = dump_json.get("data")
|
|
26
|
+
if dump_data is None:
|
|
27
|
+
raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
|
|
28
28
|
if not dump_data:
|
|
29
|
-
|
|
29
|
+
logger.warning("data field is empty, no overflow data found.")
|
|
30
30
|
|
|
31
31
|
forward_data = {}
|
|
32
32
|
backward_data = {}
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import io
|
|
17
17
|
import os
|
|
18
|
+
import pickle
|
|
18
19
|
import random
|
|
19
20
|
import stat
|
|
20
21
|
from functools import wraps
|
|
@@ -24,7 +25,7 @@ import torch
|
|
|
24
25
|
import torch.distributed as dist
|
|
25
26
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
26
27
|
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
27
|
-
check_file_or_directory_path, check_path_before_create)
|
|
28
|
+
check_file_or_directory_path, check_path_before_create, FileOpen)
|
|
28
29
|
from msprobe.core.common.log import logger
|
|
29
30
|
from msprobe.core.common.utils import check_seed_all
|
|
30
31
|
from packaging import version
|
|
@@ -75,7 +76,7 @@ def parameter_adapter(func):
|
|
|
75
76
|
else:
|
|
76
77
|
res = [input_tensor[tensor_index] for tensor_index in indices]
|
|
77
78
|
return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
|
|
78
|
-
if self.op_name_ == "__eq__" and args[1] is None:
|
|
79
|
+
if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
|
|
79
80
|
return False
|
|
80
81
|
return func(self, *args, **kwargs)
|
|
81
82
|
|
|
@@ -269,17 +270,17 @@ def load_pt(pt_path, to_cpu=False):
|
|
|
269
270
|
check_file_or_directory_path(pt_path)
|
|
270
271
|
try:
|
|
271
272
|
if to_cpu:
|
|
272
|
-
pt = torch.load(pt_path, map_location=torch.device("cpu"))
|
|
273
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
|
|
273
274
|
else:
|
|
274
|
-
pt = torch.load(pt_path)
|
|
275
|
+
pt = torch.load(pt_path, weights_only=True)
|
|
275
276
|
except Exception as e:
|
|
276
277
|
raise RuntimeError(f"load pt file {pt_path} failed") from e
|
|
277
278
|
return pt
|
|
278
279
|
|
|
279
280
|
|
|
280
281
|
def save_pt(tensor, filepath):
|
|
281
|
-
filepath = os.path.realpath(filepath)
|
|
282
282
|
check_path_before_create(filepath)
|
|
283
|
+
filepath = os.path.realpath(filepath)
|
|
283
284
|
try:
|
|
284
285
|
torch.save(tensor, filepath)
|
|
285
286
|
except Exception as e:
|
|
@@ -290,6 +291,56 @@ def save_pt(tensor, filepath):
|
|
|
290
291
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
291
292
|
|
|
292
293
|
|
|
294
|
+
class TypeCheckingUnpickler(pickle.Unpickler):
|
|
295
|
+
"""
|
|
296
|
+
This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
|
|
297
|
+
It overrides the find_class method to add type checking functionality.
|
|
298
|
+
"""
|
|
299
|
+
allowed_types = [
|
|
300
|
+
"str",
|
|
301
|
+
"ApiData",
|
|
302
|
+
"OrderedDict",
|
|
303
|
+
"_rebuild_tensor_v2", # from torch.utils
|
|
304
|
+
"_load_from_bytes" # from torch.storage
|
|
305
|
+
]
|
|
306
|
+
|
|
307
|
+
def find_class(self, module, name):
|
|
308
|
+
"""
|
|
309
|
+
Method to find the class of the object to be unpickled.
|
|
310
|
+
Throws pickle.UnpicklingError If the object type is not in the allowed types list.
|
|
311
|
+
"""
|
|
312
|
+
if name in self.allowed_types:
|
|
313
|
+
return super().find_class(module, name)
|
|
314
|
+
raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def save_pkl(tensor, filepath):
|
|
318
|
+
"""Save ApiData or str objection by pickle"""
|
|
319
|
+
check_path_before_create(filepath)
|
|
320
|
+
filepath = os.path.realpath(filepath)
|
|
321
|
+
try:
|
|
322
|
+
with FileOpen(filepath, 'wb') as f:
|
|
323
|
+
pickle.dump(tensor, f)
|
|
324
|
+
except Exception as e:
|
|
325
|
+
logger.error("Save pt file failed, please check according possible error causes: "
|
|
326
|
+
"1. out of disk space or disk error, "
|
|
327
|
+
"2. no permission to write files, etc.")
|
|
328
|
+
raise RuntimeError(f"save pt file {filepath} failed") from e
|
|
329
|
+
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def load_pkl(pt_path):
|
|
333
|
+
"""Load ApiData or str objection by pickle for accuracy_checker_online"""
|
|
334
|
+
check_file_or_directory_path(pt_path)
|
|
335
|
+
pt_path = os.path.realpath(pt_path)
|
|
336
|
+
try:
|
|
337
|
+
with FileOpen(pt_path, 'rb') as f:
|
|
338
|
+
pt = TypeCheckingUnpickler(f).load()
|
|
339
|
+
except Exception as e:
|
|
340
|
+
raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
|
|
341
|
+
return pt
|
|
342
|
+
|
|
343
|
+
|
|
293
344
|
def save_api_data(api_data):
|
|
294
345
|
"""Save data to io stream"""
|
|
295
346
|
try:
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
18
|
-
check_configuration_param,
|
|
18
|
+
check_configuration_param, set_dump_path, get_dump_mode
|
|
19
19
|
from msprobe.core.common.file_utils import create_directory
|
|
20
20
|
from msprobe.core.common.exceptions import FileCheckException
|
|
21
21
|
from msprobe.pytorch.common.log import logger
|
|
@@ -30,6 +30,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
30
30
|
stack_mode = kwargs.get('stack_mode', False)
|
|
31
31
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
32
32
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
33
|
+
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
33
34
|
# get the ranks and match by order
|
|
34
35
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
35
36
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
@@ -49,18 +50,16 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
49
50
|
'npu_json_path': npu_path,
|
|
50
51
|
'bench_json_path': bench_path,
|
|
51
52
|
'stack_json_path': stack_path,
|
|
52
|
-
'is_print_compare_log':
|
|
53
|
+
'is_print_compare_log': is_print_compare_log
|
|
53
54
|
}
|
|
54
55
|
try:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
56
|
+
set_dump_path(dump_result_param)
|
|
57
|
+
dump_mode = get_dump_mode(dump_result_param)
|
|
58
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
|
|
58
59
|
create_directory(output_path)
|
|
59
|
-
check_compare_param(dump_result_param, output_path,
|
|
60
|
-
summary_compare=summary_compare, md5_compare=md5_compare)
|
|
60
|
+
check_compare_param(dump_result_param, output_path, dump_mode)
|
|
61
61
|
except (CompareException, FileCheckException) as error:
|
|
62
62
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
63
63
|
raise CompareException(error.code) from error
|
|
64
64
|
pt_comparator = PTComparator()
|
|
65
|
-
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
|
|
66
|
-
summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
|
|
65
|
+
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
|
|
@@ -19,8 +19,8 @@ from msprobe.core.common.const import FileCheckConst
|
|
|
19
19
|
from msprobe.pytorch.common.log import logger
|
|
20
20
|
from msprobe.core.common.exceptions import FileCheckException
|
|
21
21
|
from msprobe.core.compare.acc_compare import Comparator
|
|
22
|
-
from msprobe.core.common.utils import check_configuration_param,
|
|
23
|
-
CompareException
|
|
22
|
+
from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
|
|
23
|
+
CompareException, set_dump_path, get_dump_mode
|
|
24
24
|
from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
|
|
25
25
|
from msprobe.pytorch.common.utils import load_pt
|
|
26
26
|
|
|
@@ -45,6 +45,8 @@ class PTComparator (Comparator):
|
|
|
45
45
|
return mapping_dict
|
|
46
46
|
|
|
47
47
|
def read_npy_data(self, dir_path, file_name):
|
|
48
|
+
if not file_name:
|
|
49
|
+
return None
|
|
48
50
|
data_path = os.path.join(dir_path, file_name)
|
|
49
51
|
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
50
52
|
FileCheckConst.PT_SUFFIX, False)
|
|
@@ -68,15 +70,15 @@ class PTComparator (Comparator):
|
|
|
68
70
|
|
|
69
71
|
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
|
|
70
72
|
try:
|
|
71
|
-
|
|
73
|
+
set_dump_path(input_param)
|
|
74
|
+
dump_mode = get_dump_mode(input_param)
|
|
72
75
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
73
76
|
create_directory(output_path)
|
|
74
|
-
check_compare_param(input_param, output_path,
|
|
77
|
+
check_compare_param(input_param, output_path, dump_mode)
|
|
75
78
|
data_mapping = kwargs.get('data_mapping', None)
|
|
76
79
|
except (CompareException, FileCheckException) as error:
|
|
77
80
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
78
81
|
raise CompareException(error.code) from error
|
|
79
82
|
pt_comparator = PTComparator(data_mapping)
|
|
80
83
|
pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
81
|
-
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match,
|
|
82
|
-
md5_compare=md5_compare)
|
|
84
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
|
|
@@ -31,14 +31,14 @@ class DebuggerConfig:
|
|
|
31
31
|
self.scope = task_config.scope if task_config.scope else []
|
|
32
32
|
self.list = task_config.list if task_config.list else []
|
|
33
33
|
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
34
|
-
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
35
|
-
self.backward_input = {}
|
|
36
|
-
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
37
|
-
self.is_forward_acl_dump = True
|
|
38
34
|
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
39
35
|
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
40
36
|
self.framework = Const.PT_FRAMEWORK
|
|
41
37
|
|
|
38
|
+
if self.level == Const.LEVEL_L2:
|
|
39
|
+
self.is_backward_kernel_dump = False
|
|
40
|
+
self._check_and_adjust_config_with_l2()
|
|
41
|
+
|
|
42
42
|
if self.task == Const.FREE_BENCHMARK:
|
|
43
43
|
self.fuzz_device = task_config.fuzz_device
|
|
44
44
|
self.handler_type = task_config.handler_type
|
|
@@ -59,20 +59,11 @@ class DebuggerConfig:
|
|
|
59
59
|
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
60
60
|
self.host = task_config.host if task_config.host else ""
|
|
61
61
|
self.port = task_config.port if task_config.port else -1
|
|
62
|
+
self.online_run_ut_recompute = task_config.online_run_ut_recompute \
|
|
63
|
+
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
62
64
|
|
|
63
65
|
self.check()
|
|
64
66
|
|
|
65
|
-
if self.level == "L2":
|
|
66
|
-
if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
|
|
67
|
-
raise ValueError("scope must be configured as a list with one api name")
|
|
68
|
-
if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
|
|
69
|
-
raise ValueError("backward_input must be configured when scope contains 'backward'")
|
|
70
|
-
if Const.BACKWARD in self.scope[0]:
|
|
71
|
-
self.is_forward_acl_dump = False
|
|
72
|
-
for index, scope_spec in enumerate(self.scope):
|
|
73
|
-
self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
|
|
74
|
-
self.backward_input[self.scope[index]] = self.backward_input_list[index]
|
|
75
|
-
|
|
76
67
|
def check_kwargs(self):
|
|
77
68
|
if self.task and self.task not in Const.TASK_LIST:
|
|
78
69
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
@@ -106,3 +97,16 @@ class DebuggerConfig:
|
|
|
106
97
|
logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
|
|
107
98
|
raise MsprobeException(
|
|
108
99
|
MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
|
|
100
|
+
|
|
101
|
+
def _check_and_adjust_config_with_l2(self):
|
|
102
|
+
if self.scope:
|
|
103
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
104
|
+
f"When level is set to L2, the scope cannot be configured.")
|
|
105
|
+
if not self.list or len(self.list) != 1:
|
|
106
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
107
|
+
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
108
|
+
api_name = self.list[0]
|
|
109
|
+
if api_name.endswith(Const.BACKWARD):
|
|
110
|
+
self.is_backward_kernel_dump = True
|
|
111
|
+
api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
|
|
112
|
+
self.list.append(api_forward_name)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.file_utils import save_json
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_kernel_config_json(dump_path, cur_rank):
|
|
22
|
+
kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
|
|
23
|
+
kernel_config_path = os.path.join(dump_path, kernel_config_name)
|
|
24
|
+
config_info = {
|
|
25
|
+
"dump": {
|
|
26
|
+
"dump_list": [],
|
|
27
|
+
"dump_path": dump_path,
|
|
28
|
+
"dump_mode": "all",
|
|
29
|
+
"dump_op_switch": "on"
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
save_json(kernel_config_path, config_info, indent=4)
|
|
33
|
+
return kernel_config_path
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Dict
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from collections import defaultdict
|
|
2
17
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const
|
|
2
17
|
|
|
3
18
|
|
|
@@ -17,6 +17,7 @@ from dataclasses import dataclass
|
|
|
17
17
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
18
18
|
|
|
19
19
|
import torch
|
|
20
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
20
21
|
from msprobe.pytorch.free_benchmark import logger
|
|
21
22
|
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
22
23
|
DeviceType,
|
|
@@ -128,7 +129,13 @@ def make_unequal_row(
|
|
|
128
129
|
row.max_rel = ratio - 1
|
|
129
130
|
origin_tensor = data_params.original_result
|
|
130
131
|
perturbed_tensor = data_params.perturbed_result
|
|
131
|
-
if index:
|
|
132
|
+
if index is not None:
|
|
133
|
+
if index >= len(origin_tensor) or index >= len(perturbed_tensor):
|
|
134
|
+
err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
|
|
135
|
+
raise FreeBenchmarkException(
|
|
136
|
+
FreeBenchmarkException.OutputIndexError,
|
|
137
|
+
error_info=err_msg,
|
|
138
|
+
)
|
|
132
139
|
origin_tensor = origin_tensor[index]
|
|
133
140
|
perturbed_tensor = perturbed_tensor[index]
|
|
134
141
|
row.output_index = index
|
|
@@ -13,7 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
|
|
16
17
|
import torch
|
|
18
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
19
|
+
from msprobe.core.common.utils import recursion_depth_decorator
|
|
17
20
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
18
21
|
|
|
19
22
|
|
|
@@ -51,6 +54,7 @@ class Tools:
|
|
|
51
54
|
return api_name.rsplit(".", 2)[0]
|
|
52
55
|
|
|
53
56
|
@staticmethod
|
|
57
|
+
@recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
|
|
54
58
|
def convert_device_and_dtype(
|
|
55
59
|
tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
|
|
56
60
|
):
|
|
@@ -73,23 +77,41 @@ class Tools:
|
|
|
73
77
|
return tensor_seq
|
|
74
78
|
|
|
75
79
|
@staticmethod
|
|
80
|
+
@recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
|
|
76
81
|
def convert_fuzz_output_to_origin(origin, perturbed):
|
|
77
|
-
if isinstance(origin, torch.Tensor):
|
|
82
|
+
if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
|
|
78
83
|
origin.data = perturbed.to(origin.dtype).to(origin.device)
|
|
79
84
|
return origin
|
|
80
|
-
if isinstance(origin, dict):
|
|
85
|
+
if isinstance(origin, dict) and isinstance(perturbed, dict):
|
|
81
86
|
output = dict()
|
|
82
87
|
for key, value in origin.items():
|
|
88
|
+
if key not in perturbed:
|
|
89
|
+
err_msg = f"'{key}' not in perturbed output."
|
|
90
|
+
raise FreeBenchmarkException(
|
|
91
|
+
FreeBenchmarkException.InvalidPerturbedOutput,
|
|
92
|
+
error_info=err_msg,
|
|
93
|
+
)
|
|
83
94
|
output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
|
|
84
95
|
return output
|
|
85
|
-
if isinstance(origin, (tuple, list)):
|
|
96
|
+
if isinstance(origin, (tuple, list)) and isinstance(perturbed, (tuple, list)):
|
|
86
97
|
result = list()
|
|
98
|
+
if len(perturbed) != len(origin):
|
|
99
|
+
err_msg = (
|
|
100
|
+
f"length of perturbed output ({len(perturbed)}) is different "
|
|
101
|
+
f"from the length of original output ({len(origin)})."
|
|
102
|
+
)
|
|
103
|
+
raise FreeBenchmarkException(
|
|
104
|
+
FreeBenchmarkException.InvalidPerturbedOutput, error_info=err_msg
|
|
105
|
+
)
|
|
87
106
|
for index_, value in enumerate(origin):
|
|
88
107
|
result.append(
|
|
89
108
|
Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
|
|
90
109
|
)
|
|
91
110
|
return type(origin)(result)
|
|
92
|
-
|
|
111
|
+
err_msg = f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported."
|
|
112
|
+
raise FreeBenchmarkException(
|
|
113
|
+
FreeBenchmarkException.UnsupportedType, error_info=err_msg
|
|
114
|
+
)
|
|
93
115
|
|
|
94
116
|
|
|
95
117
|
class TorchC:
|