mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -12,13 +12,16 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
from datetime import datetime
|
|
16
18
|
from mindspore import dtype as mstype, Tensor
|
|
17
19
|
|
|
18
20
|
from msprobe.mindspore.monitor.features import FUNC_MAP
|
|
19
21
|
from msprobe.core.common.const import MonitorConst
|
|
20
22
|
from msprobe.core.common.utils import is_int
|
|
21
23
|
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
def get_single_metrics(op_list, tag, tensor, output=None):
|
|
@@ -95,8 +98,8 @@ def validate_ranks(ranks):
|
|
|
95
98
|
if not isinstance(ranks, list):
|
|
96
99
|
raise TypeError("module_ranks should be a list")
|
|
97
100
|
for rank in ranks:
|
|
98
|
-
if not isinstance(rank,
|
|
99
|
-
raise TypeError(f"element in module_ranks should be a
|
|
101
|
+
if not isinstance(rank, int):
|
|
102
|
+
raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
|
|
100
103
|
|
|
101
104
|
|
|
102
105
|
def validate_targets(targets):
|
|
@@ -209,6 +212,11 @@ def validate_collect_times(collect_times):
|
|
|
209
212
|
raise ValueError("collect_times must greater than 1")
|
|
210
213
|
|
|
211
214
|
|
|
215
|
+
def validate_dynamic_on(dynamic_on):
|
|
216
|
+
if not isinstance(dynamic_on, bool):
|
|
217
|
+
raise TypeError('dynamic_on should be a bool')
|
|
218
|
+
|
|
219
|
+
|
|
212
220
|
def validate_config(config):
|
|
213
221
|
config['ops'] = validate_ops(config.get('ops', []))
|
|
214
222
|
|
|
@@ -255,9 +263,12 @@ def validate_config(config):
|
|
|
255
263
|
step_interval = config.get('step_interval', 1)
|
|
256
264
|
validate_step_interval(step_interval)
|
|
257
265
|
|
|
258
|
-
collect_times = config.get('collect_times', 1e8)
|
|
266
|
+
collect_times = config.get('collect_times', int(1e8))
|
|
259
267
|
validate_collect_times(collect_times)
|
|
260
268
|
|
|
269
|
+
dynamic_on = config.get('dynamic_on', False)
|
|
270
|
+
validate_dynamic_on(dynamic_on)
|
|
271
|
+
|
|
261
272
|
if not targets:
|
|
262
273
|
if xy_distribution:
|
|
263
274
|
config["all_xy"] = True
|
|
@@ -265,3 +276,34 @@ def validate_config(config):
|
|
|
265
276
|
config["is_select"] = False
|
|
266
277
|
else:
|
|
267
278
|
config["is_select"] = True
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def time_str2time_digit(time_str):
|
|
282
|
+
time_format = '%b%d_%H-%M-%S'
|
|
283
|
+
try:
|
|
284
|
+
time_digit = datetime.strptime(time_str, time_format)
|
|
285
|
+
except Exception as e:
|
|
286
|
+
raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
|
|
287
|
+
of existing output dirpath, like 'Dec03_21-34-40'.") from e
|
|
288
|
+
return time_digit
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def get_target_output_dir(monitor_path, time_start, time_end):
|
|
292
|
+
check_file_or_directory_path(monitor_path, isdir=True)
|
|
293
|
+
time_start = time_str2time_digit(time_start) if time_start is not None else time_start
|
|
294
|
+
time_end = time_str2time_digit(time_end) if time_end is not None else time_end
|
|
295
|
+
if time_start and time_end and time_start > time_end:
|
|
296
|
+
raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
|
|
297
|
+
result = {}
|
|
298
|
+
for dirname in os.listdir(monitor_path):
|
|
299
|
+
match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
|
|
300
|
+
if not match:
|
|
301
|
+
continue
|
|
302
|
+
time_tag = match.group(1)
|
|
303
|
+
rank = match.group(2)
|
|
304
|
+
target_time = time_str2time_digit(time_tag)
|
|
305
|
+
start_ok = time_start is None or target_time >= time_start
|
|
306
|
+
end_ok = time_end is None or target_time <= time_end
|
|
307
|
+
if start_ok and end_ok:
|
|
308
|
+
result[rank] = os.path.join(monitor_path, dirname)
|
|
309
|
+
return result
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from msprobe.core.common.log import logger
|
|
16
17
|
from msprobe.mindspore.common.const import Const
|
|
17
18
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
18
19
|
from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
|
|
@@ -44,6 +45,7 @@ class OverflowCheckToolFactory:
|
|
|
44
45
|
raise Exception("Valid level is needed.")
|
|
45
46
|
tool = tool.get(config.execution_mode)
|
|
46
47
|
if not tool:
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
logger.error(f"Overflow check is not supported in {config.execution_mode} mode "
|
|
49
|
+
f"when level is {config.level}.")
|
|
50
|
+
raise ValueError
|
|
49
51
|
return tool(config)
|
msprobe/mindspore/service.py
CHANGED
|
@@ -41,7 +41,7 @@ from msprobe.mindspore.cell_processor import CellProcessor
|
|
|
41
41
|
from msprobe.mindspore.common.log import logger
|
|
42
42
|
from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
|
|
43
43
|
is_mindtorch, register_backward_hook_functions)
|
|
44
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
44
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
45
45
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
46
46
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
47
47
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
@@ -63,6 +63,8 @@ class Service:
|
|
|
63
63
|
self.inner_switch = False
|
|
64
64
|
self.primitive_switch = False
|
|
65
65
|
self.current_iter = 0
|
|
66
|
+
self.loop = 0
|
|
67
|
+
self.init_step = 0
|
|
66
68
|
self.first_start = True
|
|
67
69
|
self.current_rank = None
|
|
68
70
|
self.dump_iter_dir = None
|
|
@@ -71,6 +73,7 @@ class Service:
|
|
|
71
73
|
self.params_grad_info = {}
|
|
72
74
|
self.hook_handle_dict = {}
|
|
73
75
|
# 提前注册,确保注册尽可能多的API hook
|
|
76
|
+
self.api_register = get_api_register()
|
|
74
77
|
self.register_api_hook()
|
|
75
78
|
self.init_for_debug_level()
|
|
76
79
|
|
|
@@ -276,11 +279,24 @@ class Service:
|
|
|
276
279
|
if self.config.task == Const.TENSOR:
|
|
277
280
|
self.data_collector.data_processor.dump_async_data()
|
|
278
281
|
self.data_collector.write_json()
|
|
279
|
-
self.
|
|
280
|
-
self.data_collector.update_iter(self.current_iter)
|
|
282
|
+
self.loop += 1
|
|
281
283
|
self.reset_status()
|
|
282
284
|
|
|
283
285
|
def start(self, model=None):
|
|
286
|
+
if self.current_iter == 0:
|
|
287
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
288
|
+
JitDump.set_config(self.config)
|
|
289
|
+
JitDump.set_data_collector(self.data_collector)
|
|
290
|
+
if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
|
|
291
|
+
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
292
|
+
else:
|
|
293
|
+
ms.common.api._JitExecutor = JitDump
|
|
294
|
+
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
295
|
+
if pijit_label:
|
|
296
|
+
PIJitCaptureContext.__enter__ = self.empty
|
|
297
|
+
PIJitCaptureContext.__exit__ = self.empty
|
|
298
|
+
self.current_iter = self.loop + self.init_step
|
|
299
|
+
self.data_collector.update_iter(self.current_iter)
|
|
284
300
|
if self.config.level == Const.LEVEL_DEBUG:
|
|
285
301
|
return
|
|
286
302
|
self.start_call = True
|
|
@@ -293,6 +309,7 @@ class Service:
|
|
|
293
309
|
print_tools_ends_info()
|
|
294
310
|
return
|
|
295
311
|
if self.config.step and self.current_iter not in self.config.step:
|
|
312
|
+
JitDump.jit_dump_switch = False
|
|
296
313
|
return
|
|
297
314
|
self.model = self.check_model_valid(model)
|
|
298
315
|
|
|
@@ -308,20 +325,9 @@ class Service:
|
|
|
308
325
|
return
|
|
309
326
|
self.register_primitive_hook()
|
|
310
327
|
self.register_cell_hook()
|
|
311
|
-
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
312
|
-
JitDump.set_config(self.config)
|
|
313
|
-
JitDump.set_data_collector(self.data_collector)
|
|
314
|
-
if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
|
|
315
|
-
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
316
|
-
else:
|
|
317
|
-
ms.common.api._JitExecutor = JitDump
|
|
318
|
-
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
319
|
-
if pijit_label:
|
|
320
|
-
PIJitCaptureContext.__enter__ = self.empty
|
|
321
|
-
PIJitCaptureContext.__exit__ = self.empty
|
|
322
328
|
self.first_start = False
|
|
323
329
|
|
|
324
|
-
api_register.
|
|
330
|
+
self.api_register.register_all_api()
|
|
325
331
|
self.switch = True
|
|
326
332
|
self.primitive_switch = True
|
|
327
333
|
logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
@@ -410,8 +416,8 @@ class Service:
|
|
|
410
416
|
def register_api_hook(self):
|
|
411
417
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
412
418
|
logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
413
|
-
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
414
|
-
api_register.
|
|
419
|
+
self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
420
|
+
self.api_register.register_all_api()
|
|
415
421
|
|
|
416
422
|
def get_cells_and_names(self):
|
|
417
423
|
cells_and_names_with_index = {}
|
|
@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
|
|
|
40
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
|
|
41
41
|
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
42
42
|
from msprobe.pytorch.common.log import logger
|
|
43
|
-
from msprobe.core.common.utils import CompareException
|
|
43
|
+
from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
|
|
44
44
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
45
45
|
|
|
46
46
|
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
|
|
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
151
151
|
message = ''
|
|
152
152
|
compare_column = ApiPrecisionOutputColumn()
|
|
153
153
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
154
|
+
check_op_str_pattern_valid(full_api_name_with_direction_status)
|
|
154
155
|
row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
|
|
155
156
|
api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
|
|
156
157
|
if not api_full_name:
|
|
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
|
|
|
430
431
|
_api_precision_compare_parser(parser)
|
|
431
432
|
args = parser.parse_args(sys.argv[1:])
|
|
432
433
|
_api_precision_compare_command(args)
|
|
434
|
+
logger.info("Compare task completed.")
|
|
433
435
|
|
|
434
436
|
|
|
435
437
|
def _api_precision_compare_command(args):
|
|
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
|
|
|
457
459
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
458
460
|
help="<optional> The api precision compare task result out path.",
|
|
459
461
|
required=False)
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if __name__ == '__main__':
|
|
463
|
-
_api_precision_compare()
|
|
464
|
-
logger.info("Compare task completed.")
|
|
@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
|
|
|
28
28
|
ulp_standard_api, thousandth_standard_api
|
|
29
29
|
from msprobe.core.common.file_utils import FileOpen, load_json, save_json
|
|
30
30
|
from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
|
|
31
|
-
from msprobe.core.common.const import Const, MonitorConst, MsgConst
|
|
31
|
+
from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
|
|
32
32
|
from msprobe.core.common.log import logger
|
|
33
|
-
from msprobe.core.common.file_utils import make_dir
|
|
34
|
-
from msprobe.core.common.
|
|
33
|
+
from msprobe.core.common.file_utils import make_dir, change_mode
|
|
34
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
35
35
|
|
|
36
36
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
37
37
|
TORCH_BOOL_TYPE = ["torch.bool"]
|
|
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
|
|
|
50
50
|
API_MAX_LENGTH = 30
|
|
51
51
|
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
|
|
52
52
|
DATAMODE_LIST = ["random_data", "real_data"]
|
|
53
|
+
ITER_MAX_TIMES = 1000
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
class APIInfo:
|
|
@@ -97,6 +98,8 @@ class CommonConfig:
|
|
|
97
98
|
iter_t = self.iter_times
|
|
98
99
|
if iter_t <= 0:
|
|
99
100
|
raise ValueError("iter_times should be an integer bigger than zero!")
|
|
101
|
+
if iter_t > ITER_MAX_TIMES:
|
|
102
|
+
raise ValueError("iter_times should not be greater than 1000!")
|
|
100
103
|
|
|
101
104
|
json_file = self.extract_api_path
|
|
102
105
|
propagation = self.propagation
|
|
@@ -117,7 +120,7 @@ class CommonConfig:
|
|
|
117
120
|
|
|
118
121
|
# Retrieve the first API name and dictionary
|
|
119
122
|
forward_item = next(iter(json_content.items()), None)
|
|
120
|
-
if not forward_item or not isinstance(forward_item[1], dict):
|
|
123
|
+
if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
|
|
121
124
|
raise ValueError(f'Invalid forward API data in json_content!')
|
|
122
125
|
|
|
123
126
|
# if propagation is backward, ensure json file contains forward and backward info
|
|
@@ -127,7 +130,7 @@ class CommonConfig:
|
|
|
127
130
|
# if propagation is backward, ensure it has valid data
|
|
128
131
|
if propagation == Const.BACKWARD:
|
|
129
132
|
backward_item = list(json_content.items())[1]
|
|
130
|
-
if not isinstance(backward_item[1], dict):
|
|
133
|
+
if not isinstance(backward_item[1], dict) or not backward_item[1]:
|
|
131
134
|
raise ValueError(f'Invalid backward API data in json_content!')
|
|
132
135
|
|
|
133
136
|
return json_content
|
|
@@ -169,7 +172,7 @@ class APIExtractor:
|
|
|
169
172
|
value = self.load_real_data_path(value, real_data_path)
|
|
170
173
|
new_data[key] = value
|
|
171
174
|
if not new_data:
|
|
172
|
-
logger.
|
|
175
|
+
logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
|
|
173
176
|
else:
|
|
174
177
|
save_json(self.output_file, new_data, indent=4)
|
|
175
178
|
logger.info(
|
|
@@ -183,6 +186,7 @@ class APIExtractor:
|
|
|
183
186
|
self.update_data_name(v, dump_data_dir)
|
|
184
187
|
return value
|
|
185
188
|
|
|
189
|
+
@recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
|
|
186
190
|
def update_data_name(self, data, dump_data_dir):
|
|
187
191
|
if isinstance(data, list):
|
|
188
192
|
for item in data:
|
|
@@ -467,6 +471,7 @@ def _run_operator_generate_commond(cmd_args):
|
|
|
467
471
|
fout.write(code_template.format(**internal_settings))
|
|
468
472
|
except OSError:
|
|
469
473
|
logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
|
|
474
|
+
change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
470
475
|
|
|
471
476
|
logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
|
|
472
477
|
|
|
@@ -37,9 +37,9 @@ def load_pt(pt_path, to_cpu=False):
|
|
|
37
37
|
pt_path = os.path.realpath(pt_path)
|
|
38
38
|
try:
|
|
39
39
|
if to_cpu:
|
|
40
|
-
pt = torch.load(pt_path, map_location=torch.device("cpu"))
|
|
40
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
|
|
41
41
|
else:
|
|
42
|
-
pt = torch.load(pt_path)
|
|
42
|
+
pt = torch.load(pt_path, weights_only=True)
|
|
43
43
|
except Exception as e:
|
|
44
44
|
raise RuntimeError(f"load pt file {{pt_path}} failed") from e
|
|
45
45
|
return pt
|
|
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
50
50
|
backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
|
|
51
51
|
|
|
52
52
|
input_data = load_json(input_file)
|
|
53
|
+
if "dump_data_dir" not in input_data.keys():
|
|
54
|
+
logger.error("Invalid input file, 'dump_data_dir' field is missing")
|
|
55
|
+
raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
|
|
53
56
|
if input_data.get("data") is None:
|
|
54
57
|
logger.error("Invalid input file, 'data' field is missing")
|
|
55
58
|
raise CompareException("Invalid input file, 'data' field is missing")
|
|
@@ -97,7 +100,7 @@ def run_parallel_ut(config):
|
|
|
97
100
|
processes = []
|
|
98
101
|
device_id_cycle = cycle(config.device_id)
|
|
99
102
|
if config.save_error_data_flag:
|
|
100
|
-
logger.info("UT task error
|
|
103
|
+
logger.info("UT task error data will be saved")
|
|
101
104
|
logger.info(f"Starting parallel UT with {config.num_splits} processes")
|
|
102
105
|
progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
|
|
103
106
|
|
|
@@ -221,7 +224,3 @@ def main():
|
|
|
221
224
|
args = parser.parse_args()
|
|
222
225
|
config = prepare_config(args)
|
|
223
226
|
run_parallel_ut(config)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
if __name__ == '__main__':
|
|
227
|
-
main()
|
|
@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
|
|
|
34
34
|
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
35
35
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
36
36
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
37
|
+
from msprobe.core.common.utils import check_op_str_pattern_valid
|
|
37
38
|
from msprobe.pytorch.common.log import logger
|
|
38
39
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
40
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
def check_tensor_overflow(x):
|
|
@@ -75,6 +77,7 @@ def check_data_overflow(x, device):
|
|
|
75
77
|
return torch_npu.npu.utils.npu_check_overflow(x)
|
|
76
78
|
|
|
77
79
|
|
|
80
|
+
@recursion_depth_decorator("is_bool_output")
|
|
78
81
|
def is_bool_output(x):
|
|
79
82
|
if isinstance(x, (tuple, list)):
|
|
80
83
|
if not x:
|
|
@@ -91,6 +94,7 @@ def run_overflow_check(forward_file):
|
|
|
91
94
|
dump_path = os.path.dirname(forward_file)
|
|
92
95
|
real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
|
|
93
96
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
97
|
+
check_op_str_pattern_valid(api_full_name)
|
|
94
98
|
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
95
99
|
continue
|
|
96
100
|
try:
|
|
@@ -161,6 +165,7 @@ def _run_overflow_check(parser=None):
|
|
|
161
165
|
_run_overflow_check_parser(parser)
|
|
162
166
|
args = parser.parse_args(sys.argv[1:])
|
|
163
167
|
_run_overflow_check_command(args)
|
|
168
|
+
logger.info("UT task completed.")
|
|
164
169
|
|
|
165
170
|
|
|
166
171
|
def _run_overflow_check_command(args):
|
|
@@ -175,8 +180,3 @@ def _run_overflow_check_command(args):
|
|
|
175
180
|
logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
|
|
176
181
|
raise NotImplementedError from error
|
|
177
182
|
run_overflow_check(api_info)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if __name__ == '__main__':
|
|
181
|
-
_run_overflow_check()
|
|
182
|
-
logger.info("UT task completed.")
|
|
@@ -49,7 +49,7 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
|
49
49
|
from msprobe.pytorch.common.log import logger
|
|
50
50
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
|
-
from msprobe.core.common.utils import safe_get_value, CompareException
|
|
52
|
+
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
54
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
55
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
@@ -65,6 +65,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
|
65
65
|
|
|
66
66
|
not_backward_list = ['repeat_interleave']
|
|
67
67
|
unsupported_backward_list = ['masked_select']
|
|
68
|
+
unsupported_api_list = ["to"]
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
tqdm_params = {
|
|
@@ -83,6 +84,9 @@ tqdm_params = {
|
|
|
83
84
|
}
|
|
84
85
|
|
|
85
86
|
|
|
87
|
+
seed_all()
|
|
88
|
+
|
|
89
|
+
|
|
86
90
|
def run_ut(config):
|
|
87
91
|
logger.info("start UT test")
|
|
88
92
|
if config.online_config.is_online:
|
|
@@ -93,7 +97,7 @@ def run_ut(config):
|
|
|
93
97
|
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
94
98
|
|
|
95
99
|
if config.save_error_data:
|
|
96
|
-
logger.info(f"UT task
|
|
100
|
+
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
97
101
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
98
102
|
|
|
99
103
|
if config.online_config.is_online:
|
|
@@ -117,6 +121,7 @@ def run_ut(config):
|
|
|
117
121
|
def run_api_offline(config, compare, api_name_set):
|
|
118
122
|
err_column = CompareColumn()
|
|
119
123
|
for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
|
|
124
|
+
check_op_str_pattern_valid(api_full_name)
|
|
120
125
|
if api_full_name in api_name_set:
|
|
121
126
|
continue
|
|
122
127
|
if is_unsupported_api(api_full_name):
|
|
@@ -218,6 +223,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
|
218
223
|
If api is both in black_list and black_list, black_list first.
|
|
219
224
|
return: False for exec api, True for not exec
|
|
220
225
|
"""
|
|
226
|
+
black_list.extend(unsupported_api_list)
|
|
221
227
|
if black_list and api_name in black_list:
|
|
222
228
|
return True
|
|
223
229
|
if white_list and api_name not in white_list:
|
|
@@ -317,7 +323,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
|
317
323
|
if kwargs.get("device"):
|
|
318
324
|
del kwargs["device"]
|
|
319
325
|
|
|
320
|
-
|
|
326
|
+
device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
|
|
327
|
+
device_out = exec_api(device_exec_params)
|
|
321
328
|
device_out = move2device_exec(device_out, "cpu")
|
|
322
329
|
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
323
330
|
|
|
@@ -344,6 +351,9 @@ def need_to_backward(grad_index, out):
|
|
|
344
351
|
|
|
345
352
|
def run_backward(args, grad, grad_index, out):
|
|
346
353
|
if grad_index is not None:
|
|
354
|
+
if not is_int(grad_index):
|
|
355
|
+
logger.error(f"{grad_index} dtype is not int")
|
|
356
|
+
raise TypeError(f"{grad_index} dtype is not int")
|
|
347
357
|
if grad_index >= len(out):
|
|
348
358
|
logger.error(f"Run backward error when grad_index is {grad_index}")
|
|
349
359
|
raise IndexError(f"Run backward error when grad_index is {grad_index}")
|
|
@@ -430,6 +440,7 @@ def preprocess_forward_content(forward_content):
|
|
|
430
440
|
arg_cache = {}
|
|
431
441
|
|
|
432
442
|
for key, value in forward_content.items():
|
|
443
|
+
check_op_str_pattern_valid(key)
|
|
433
444
|
base_key = key.rsplit(Const.SEP, 1)[0]
|
|
434
445
|
|
|
435
446
|
if key not in arg_cache:
|
|
@@ -469,6 +480,7 @@ def _run_ut(parser=None):
|
|
|
469
480
|
_run_ut_parser(parser)
|
|
470
481
|
args = parser.parse_args(sys.argv[1:])
|
|
471
482
|
run_ut_command(args)
|
|
483
|
+
|
|
472
484
|
|
|
473
485
|
|
|
474
486
|
def checked_online_config(online_config):
|
|
@@ -492,6 +504,7 @@ def checked_online_config(online_config):
|
|
|
492
504
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
493
505
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
494
506
|
check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
|
|
507
|
+
check_crt_valid(os.path.join(online_config.tls_path, "server.key"), True)
|
|
495
508
|
|
|
496
509
|
# host and port
|
|
497
510
|
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
@@ -561,7 +574,14 @@ def run_ut_command(args):
|
|
|
561
574
|
error_data_path = checker_config.error_data_path
|
|
562
575
|
if save_error_data:
|
|
563
576
|
if args.result_csv_path:
|
|
564
|
-
|
|
577
|
+
parts_by_dot = result_csv_path.split(Const.SEP)
|
|
578
|
+
if len(parts_by_dot) < 2 or not parts_by_dot[0]:
|
|
579
|
+
raise ValueError("result_csv_path does not contain a valid file name with an extension.")
|
|
580
|
+
file_name_part = parts_by_dot[0]
|
|
581
|
+
parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
|
|
582
|
+
if len(parts_by_underscore) < 2:
|
|
583
|
+
raise ValueError("File name part does not contain enough '_' separated segments.")
|
|
584
|
+
time_info = parts_by_underscore[-1]
|
|
565
585
|
global UT_ERROR_DATA_DIR
|
|
566
586
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
567
587
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
@@ -579,9 +599,8 @@ def run_ut_command(args):
|
|
|
579
599
|
}
|
|
580
600
|
run_ut_config = checker_config.get_run_ut_config(**config_params)
|
|
581
601
|
run_ut(run_ut_config)
|
|
602
|
+
logger.info("UT task completed.")
|
|
582
603
|
|
|
583
604
|
|
|
584
605
|
if __name__ == '__main__':
|
|
585
|
-
seed_all()
|
|
586
606
|
_run_ut()
|
|
587
|
-
logger.info("UT task completed.")
|
|
@@ -1,9 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
4
2
|
# All rights reserved.
|
|
5
3
|
#
|
|
6
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
5
|
# you may not use this file except in compliance with the License.
|
|
8
6
|
# You may obtain a copy of the License at
|
|
9
7
|
#
|
|
@@ -18,8 +16,8 @@
|
|
|
18
16
|
import os
|
|
19
17
|
from collections import namedtuple
|
|
20
18
|
import re
|
|
21
|
-
import torch
|
|
22
19
|
|
|
20
|
+
import torch
|
|
23
21
|
try:
|
|
24
22
|
import torch_npu
|
|
25
23
|
except ImportError:
|
|
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
|
33
31
|
from msprobe.core.common.file_utils import FileChecker
|
|
34
32
|
from msprobe.core.common.log import logger
|
|
35
33
|
from msprobe.core.common.utils import CompareException
|
|
34
|
+
from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
|
|
36
35
|
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
37
|
-
|
|
38
|
-
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
39
|
-
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
40
|
-
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
36
|
+
|
|
41
37
|
|
|
42
38
|
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
43
39
|
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
|
|
@@ -108,17 +104,30 @@ def exec_api(exec_params):
|
|
|
108
104
|
kwargs = exec_params.kwargs
|
|
109
105
|
is_autocast = exec_params.is_autocast
|
|
110
106
|
autocast_dtype = exec_params.autocast_dtype
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if api_type
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
107
|
+
out = None
|
|
108
|
+
|
|
109
|
+
prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
|
|
110
|
+
if not prefix_map or api_type not in prefix_map.values() or \
|
|
111
|
+
api_type not in (
|
|
112
|
+
Const.FUNCTIONAL_API_TYPE_PREFIX,
|
|
113
|
+
Const.TENSOR_API_TYPE_PREFIX,
|
|
114
|
+
Const.TORCH_API_TYPE_PREFIX,
|
|
115
|
+
Const.ATEN_API_TYPE_PREFIX,
|
|
116
|
+
Const.NPU_API_TYPE_PREFIX
|
|
117
|
+
):
|
|
118
|
+
return out
|
|
119
|
+
|
|
120
|
+
if api_type == Const.ATEN_API_TYPE_PREFIX:
|
|
119
121
|
torch_api = AtenOPTemplate(api_name, None, False)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
+
else:
|
|
123
|
+
api_register = get_api_register()
|
|
124
|
+
api_register.initialize_hook(None)
|
|
125
|
+
api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
|
|
126
|
+
api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
|
|
127
|
+
if api_func is None:
|
|
128
|
+
return out
|
|
129
|
+
|
|
130
|
+
torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
|
|
122
131
|
if is_autocast:
|
|
123
132
|
with autocast(dtype=autocast_dtype):
|
|
124
133
|
out = torch_api.forward(*args, **kwargs)
|
|
@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
|
|
|
27
27
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
28
28
|
from msprobe.core.common.file_utils import remove_path
|
|
29
29
|
from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
30
31
|
|
|
31
32
|
BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
32
33
|
|
|
@@ -168,11 +169,12 @@ class ATTL:
|
|
|
168
169
|
return buffer
|
|
169
170
|
|
|
170
171
|
|
|
172
|
+
@recursion_depth_decorator("move2device_exec")
|
|
171
173
|
def move2device_exec(obj, device):
|
|
172
174
|
if isinstance(obj, (tuple, list)):
|
|
173
175
|
data_list = [move2device_exec(val, device) for val in obj]
|
|
174
176
|
return data_list if isinstance(obj, list) else tuple(data_list)
|
|
175
|
-
if isinstance(obj, dict):
|
|
177
|
+
if isinstance(obj, dict):
|
|
176
178
|
return {key: move2device_exec(val, device) for key, val in obj.items()}
|
|
177
179
|
elif isinstance(obj, torch.Tensor):
|
|
178
180
|
obj = obj.detach()
|
|
@@ -29,6 +29,8 @@ def softmax_func(x, axis=None):
|
|
|
29
29
|
|
|
30
30
|
def npu_moe_gating_top_k_softmax(x, finished_optional, k):
|
|
31
31
|
input_dtype = x.dtype
|
|
32
|
+
if x.dim() < 1:
|
|
33
|
+
raise ValueError("Input x must have at least 1 dimensions.")
|
|
32
34
|
num_expert = x.shape[-1]
|
|
33
35
|
softmax = softmax_func(x, -1)
|
|
34
36
|
softmax = softmax.to(input_dtype)
|
|
@@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k):
|
|
|
36
38
|
expert_idx = expert_idx[:, :k]
|
|
37
39
|
y = torch.gather(softmax, index=expert_idx, dim=-1)
|
|
38
40
|
if finished_optional is not None:
|
|
41
|
+
if finished_optional.dim() < 1:
|
|
42
|
+
raise ValueError("Finished_optional must have at least 1 dimensions.")
|
|
39
43
|
finished_optional = finished_optional.view(finished_optional.shape[0], 1)
|
|
40
44
|
finished_optional = finished_optional.expand(-1, k)
|
|
41
45
|
expert_idx = torch.where(finished_optional, num_expert, expert_idx)
|
|
46
|
+
if y.dim() < 2:
|
|
47
|
+
raise ValueError("Variable y must have at least 2 dimensions.")
|
|
42
48
|
row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
|
|
43
49
|
|
|
44
50
|
return y, expert_idx, row_idx
|