mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -22,6 +22,7 @@ import mindspore as ms
|
|
|
22
22
|
from mindspore import nn
|
|
23
23
|
from mindspore.common.api import _no_grad
|
|
24
24
|
from mindspore.ops.primitive import Primitive
|
|
25
|
+
|
|
25
26
|
try:
|
|
26
27
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
27
28
|
except ImportError:
|
|
@@ -31,7 +32,7 @@ else:
|
|
|
31
32
|
|
|
32
33
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
33
34
|
from msprobe.core.common.file_utils import create_directory
|
|
34
|
-
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
35
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
|
|
35
36
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
36
37
|
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
|
|
37
38
|
ModuleBackwardInputs)
|
|
@@ -40,7 +41,7 @@ from msprobe.mindspore.cell_processor import CellProcessor
|
|
|
40
41
|
from msprobe.mindspore.common.log import logger
|
|
41
42
|
from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
|
|
42
43
|
is_mindtorch, register_backward_hook_functions)
|
|
43
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
44
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
44
45
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
45
46
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
46
47
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
@@ -62,14 +63,19 @@ class Service:
|
|
|
62
63
|
self.inner_switch = False
|
|
63
64
|
self.primitive_switch = False
|
|
64
65
|
self.current_iter = 0
|
|
66
|
+
self.loop = 0
|
|
67
|
+
self.init_step = 0
|
|
65
68
|
self.first_start = True
|
|
66
69
|
self.current_rank = None
|
|
67
70
|
self.dump_iter_dir = None
|
|
68
71
|
self.start_call = False
|
|
69
72
|
self.should_stop_service = False
|
|
70
73
|
self.params_grad_info = {}
|
|
74
|
+
self.hook_handle_dict = {}
|
|
71
75
|
# 提前注册,确保注册尽可能多的API hook
|
|
76
|
+
self.api_register = get_api_register()
|
|
72
77
|
self.register_api_hook()
|
|
78
|
+
self.init_for_debug_level()
|
|
73
79
|
|
|
74
80
|
@staticmethod
|
|
75
81
|
def check_model_valid(models):
|
|
@@ -138,7 +144,12 @@ class Service:
|
|
|
138
144
|
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
139
145
|
for param_name, param in params_dict.items():
|
|
140
146
|
if param.requires_grad:
|
|
141
|
-
|
|
147
|
+
name = ori_name + Const.SEP + param_name
|
|
148
|
+
old_handle = self.hook_handle_dict.get(name)
|
|
149
|
+
if old_handle and hasattr(old_handle, "remove"):
|
|
150
|
+
old_handle.remove()
|
|
151
|
+
handle = param.register_hook(grad_hook(cell, ori_name, param_name))
|
|
152
|
+
self.hook_handle_dict[name] = handle
|
|
142
153
|
|
|
143
154
|
def init_params_grad_info(cell, params_dict):
|
|
144
155
|
'''
|
|
@@ -168,11 +179,15 @@ class Service:
|
|
|
168
179
|
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
|
|
169
180
|
if target_type == BaseScope.Module_Type_Module:
|
|
170
181
|
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
171
|
-
params_dict = {
|
|
172
|
-
|
|
173
|
-
|
|
182
|
+
params_dict = {}
|
|
183
|
+
if self.config.task != Const.STRUCTURE:
|
|
184
|
+
params_dict = {
|
|
185
|
+
key.split(Const.SEP)[-1]: value
|
|
186
|
+
for key, value in cell.parameters_dict(recurse=False).items()
|
|
187
|
+
}
|
|
188
|
+
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
174
189
|
# 判断是否需要注册参数hook
|
|
175
|
-
if
|
|
190
|
+
if params_dict:
|
|
176
191
|
ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
|
|
177
192
|
grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
|
|
178
193
|
# 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
|
|
@@ -257,15 +272,33 @@ class Service:
|
|
|
257
272
|
self.primitive_counters[primitive_name] += 1
|
|
258
273
|
|
|
259
274
|
def step(self):
|
|
275
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
276
|
+
return
|
|
260
277
|
if self.config.async_dump:
|
|
261
278
|
self.data_collector.fill_stack_tensor_data()
|
|
262
|
-
self.
|
|
279
|
+
if self.config.task == Const.TENSOR:
|
|
280
|
+
self.data_collector.data_processor.dump_async_data()
|
|
263
281
|
self.data_collector.write_json()
|
|
264
|
-
self.
|
|
265
|
-
self.data_collector.update_iter(self.current_iter)
|
|
282
|
+
self.loop += 1
|
|
266
283
|
self.reset_status()
|
|
267
284
|
|
|
268
285
|
def start(self, model=None):
|
|
286
|
+
if self.current_iter == 0:
|
|
287
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
288
|
+
JitDump.set_config(self.config)
|
|
289
|
+
JitDump.set_data_collector(self.data_collector)
|
|
290
|
+
if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
|
|
291
|
+
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
292
|
+
else:
|
|
293
|
+
ms.common.api._JitExecutor = JitDump
|
|
294
|
+
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
295
|
+
if pijit_label:
|
|
296
|
+
PIJitCaptureContext.__enter__ = self.empty
|
|
297
|
+
PIJitCaptureContext.__exit__ = self.empty
|
|
298
|
+
self.current_iter = self.loop + self.init_step
|
|
299
|
+
self.data_collector.update_iter(self.current_iter)
|
|
300
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
301
|
+
return
|
|
269
302
|
self.start_call = True
|
|
270
303
|
if self.should_stop_service:
|
|
271
304
|
return
|
|
@@ -276,6 +309,7 @@ class Service:
|
|
|
276
309
|
print_tools_ends_info()
|
|
277
310
|
return
|
|
278
311
|
if self.config.step and self.current_iter not in self.config.step:
|
|
312
|
+
JitDump.jit_dump_switch = False
|
|
279
313
|
return
|
|
280
314
|
self.model = self.check_model_valid(model)
|
|
281
315
|
|
|
@@ -291,17 +325,9 @@ class Service:
|
|
|
291
325
|
return
|
|
292
326
|
self.register_primitive_hook()
|
|
293
327
|
self.register_cell_hook()
|
|
294
|
-
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
295
|
-
JitDump.set_config(self.config)
|
|
296
|
-
JitDump.set_data_collector(self.data_collector)
|
|
297
|
-
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
298
|
-
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
299
|
-
if pijit_label:
|
|
300
|
-
PIJitCaptureContext.__enter__ = self.empty
|
|
301
|
-
PIJitCaptureContext.__exit__ = self.empty
|
|
302
328
|
self.first_start = False
|
|
303
329
|
|
|
304
|
-
api_register.
|
|
330
|
+
self.api_register.register_all_api()
|
|
305
331
|
self.switch = True
|
|
306
332
|
self.primitive_switch = True
|
|
307
333
|
logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
@@ -310,6 +336,8 @@ class Service:
|
|
|
310
336
|
JitDump.jit_dump_switch = True
|
|
311
337
|
|
|
312
338
|
def stop(self):
|
|
339
|
+
if self.config.level == Const.LEVEL_DEBUG:
|
|
340
|
+
return
|
|
313
341
|
if self.should_stop_service:
|
|
314
342
|
return
|
|
315
343
|
logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
|
|
@@ -326,7 +354,8 @@ class Service:
|
|
|
326
354
|
self.start_call = False
|
|
327
355
|
if self.config.async_dump:
|
|
328
356
|
self.data_collector.fill_stack_tensor_data()
|
|
329
|
-
self.
|
|
357
|
+
if self.config.task == Const.TENSOR:
|
|
358
|
+
self.data_collector.data_processor.dump_async_data()
|
|
330
359
|
self.data_collector.write_json()
|
|
331
360
|
JitDump.jit_dump_switch = False
|
|
332
361
|
|
|
@@ -370,12 +399,13 @@ class Service:
|
|
|
370
399
|
else:
|
|
371
400
|
dump_data_dir = None
|
|
372
401
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
)
|
|
402
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
403
|
+
dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
|
|
404
|
+
dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
|
|
405
|
+
dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
|
|
406
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
407
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
408
|
+
|
|
379
409
|
self.data_collector.initialize_json_file(
|
|
380
410
|
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
381
411
|
)
|
|
@@ -386,21 +416,21 @@ class Service:
|
|
|
386
416
|
def register_api_hook(self):
|
|
387
417
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
388
418
|
logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
389
|
-
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
390
|
-
api_register.
|
|
419
|
+
self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
420
|
+
self.api_register.register_all_api()
|
|
391
421
|
|
|
392
422
|
def get_cells_and_names(self):
|
|
393
423
|
cells_and_names_with_index = {}
|
|
394
424
|
|
|
395
425
|
def get_cell_or_module(model):
|
|
396
426
|
return model.named_modules() if is_mindtorch() else model.cells_and_names()
|
|
397
|
-
|
|
427
|
+
|
|
398
428
|
if isinstance(self.model, (list, tuple)):
|
|
399
429
|
for index, model in enumerate(self.model):
|
|
400
430
|
cells_and_names_with_index[str(index)] = get_cell_or_module(model)
|
|
401
431
|
else:
|
|
402
432
|
cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
|
|
403
|
-
return cells_and_names_with_index
|
|
433
|
+
return cells_and_names_with_index
|
|
404
434
|
|
|
405
435
|
def register_primitive_hook(self):
|
|
406
436
|
if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
@@ -430,7 +460,7 @@ class Service:
|
|
|
430
460
|
if not self.model:
|
|
431
461
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
432
462
|
f"The current level is {self.config.level}, the model cannot be None")
|
|
433
|
-
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
463
|
+
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
434
464
|
cells_and_names_with_index = self.get_cells_and_names()
|
|
435
465
|
|
|
436
466
|
for index, cells_and_names in cells_and_names_with_index.items():
|
|
@@ -439,7 +469,7 @@ class Service:
|
|
|
439
469
|
if cell == model:
|
|
440
470
|
continue
|
|
441
471
|
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
442
|
-
prefix = (model_type + Const.SEP + cell_index + name +
|
|
472
|
+
prefix = (model_type + Const.SEP + cell_index + name +
|
|
443
473
|
Const.SEP + cell.__class__.__name__ + Const.SEP)
|
|
444
474
|
_, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
|
|
445
475
|
cell.register_forward_hook(forward_hook)
|
|
@@ -456,10 +486,9 @@ class Service:
|
|
|
456
486
|
|
|
457
487
|
def reset_status(self):
|
|
458
488
|
self.primitive_hook_service.primitive_counters.clear()
|
|
459
|
-
self.data_collector.
|
|
489
|
+
self.data_collector.reset_status()
|
|
460
490
|
JitDump.jit_count = defaultdict(int)
|
|
461
491
|
self.params_grad_info.clear()
|
|
462
|
-
|
|
463
492
|
if self.config.level == Const.LEVEL_L2:
|
|
464
493
|
self.data_collector.data_processor.reset_status()
|
|
465
494
|
return
|
|
@@ -467,3 +496,54 @@ class Service:
|
|
|
467
496
|
return
|
|
468
497
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
469
498
|
return
|
|
499
|
+
|
|
500
|
+
def init_for_debug_level(self):
|
|
501
|
+
if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
|
|
502
|
+
return
|
|
503
|
+
try:
|
|
504
|
+
self.current_rank = get_rank_if_initialized()
|
|
505
|
+
except DistributedNotInitializedError:
|
|
506
|
+
self.current_rank = None
|
|
507
|
+
# dir: dump_path -- rank{} -- debug.json
|
|
508
|
+
self.dump_iter_dir = self.config.dump_path
|
|
509
|
+
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
510
|
+
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
511
|
+
create_directory(dump_dir)
|
|
512
|
+
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
513
|
+
dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
|
|
514
|
+
create_directory(dump_data_dir)
|
|
515
|
+
else:
|
|
516
|
+
dump_data_dir = None
|
|
517
|
+
|
|
518
|
+
dump_path_aggregation = DumpPathAggregation()
|
|
519
|
+
dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
|
|
520
|
+
dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
|
|
521
|
+
self.data_collector.update_dump_paths(dump_path_aggregation)
|
|
522
|
+
self.data_collector.initialize_json_file(
|
|
523
|
+
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
524
|
+
)
|
|
525
|
+
self.debug_variable_counter = defaultdict(int)
|
|
526
|
+
|
|
527
|
+
def save(self, variable, name, save_backward):
|
|
528
|
+
'''
|
|
529
|
+
Args:
|
|
530
|
+
variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
|
|
531
|
+
name: str
|
|
532
|
+
save_backward: boolean
|
|
533
|
+
Return:
|
|
534
|
+
void
|
|
535
|
+
'''
|
|
536
|
+
if self.config.level != Const.LEVEL_DEBUG:
|
|
537
|
+
return
|
|
538
|
+
count = self.debug_variable_counter[name]
|
|
539
|
+
self.debug_variable_counter[name] += 1
|
|
540
|
+
|
|
541
|
+
name_with_count = f"{name}.{count}"
|
|
542
|
+
grad_name_with_count = f"{name}_grad.{count}"
|
|
543
|
+
|
|
544
|
+
# forward save
|
|
545
|
+
self.data_collector.debug_data_collect_forward(variable, name_with_count)
|
|
546
|
+
|
|
547
|
+
# backward save
|
|
548
|
+
if save_backward:
|
|
549
|
+
self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
|
msprobe/pytorch/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
|
|
|
40
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
|
|
41
41
|
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
42
42
|
from msprobe.pytorch.common.log import logger
|
|
43
|
-
from msprobe.core.common.utils import CompareException
|
|
43
|
+
from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
|
|
44
44
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
45
45
|
|
|
46
46
|
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
|
|
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
|
|
|
151
151
|
message = ''
|
|
152
152
|
compare_column = ApiPrecisionOutputColumn()
|
|
153
153
|
full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
|
|
154
|
+
check_op_str_pattern_valid(full_api_name_with_direction_status)
|
|
154
155
|
row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
|
|
155
156
|
api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
|
|
156
157
|
if not api_full_name:
|
|
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
|
|
|
430
431
|
_api_precision_compare_parser(parser)
|
|
431
432
|
args = parser.parse_args(sys.argv[1:])
|
|
432
433
|
_api_precision_compare_command(args)
|
|
434
|
+
logger.info("Compare task completed.")
|
|
433
435
|
|
|
434
436
|
|
|
435
437
|
def _api_precision_compare_command(args):
|
|
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
|
|
|
457
459
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
458
460
|
help="<optional> The api precision compare task result out path.",
|
|
459
461
|
required=False)
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if __name__ == '__main__':
|
|
463
|
-
_api_precision_compare()
|
|
464
|
-
logger.info("Compare task completed.")
|
|
@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
|
|
|
28
28
|
ulp_standard_api, thousandth_standard_api
|
|
29
29
|
from msprobe.core.common.file_utils import FileOpen, load_json, save_json
|
|
30
30
|
from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
|
|
31
|
-
from msprobe.core.common.const import Const, MonitorConst, MsgConst
|
|
31
|
+
from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
|
|
32
32
|
from msprobe.core.common.log import logger
|
|
33
|
-
from msprobe.core.common.file_utils import make_dir
|
|
34
|
-
from msprobe.core.common.
|
|
33
|
+
from msprobe.core.common.file_utils import make_dir, change_mode
|
|
34
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
35
35
|
|
|
36
36
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
37
37
|
TORCH_BOOL_TYPE = ["torch.bool"]
|
|
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
|
|
|
50
50
|
API_MAX_LENGTH = 30
|
|
51
51
|
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
|
|
52
52
|
DATAMODE_LIST = ["random_data", "real_data"]
|
|
53
|
+
ITER_MAX_TIMES = 1000
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
class APIInfo:
|
|
@@ -97,6 +98,8 @@ class CommonConfig:
|
|
|
97
98
|
iter_t = self.iter_times
|
|
98
99
|
if iter_t <= 0:
|
|
99
100
|
raise ValueError("iter_times should be an integer bigger than zero!")
|
|
101
|
+
if iter_t > ITER_MAX_TIMES:
|
|
102
|
+
raise ValueError("iter_times should not be greater than 1000!")
|
|
100
103
|
|
|
101
104
|
json_file = self.extract_api_path
|
|
102
105
|
propagation = self.propagation
|
|
@@ -117,7 +120,7 @@ class CommonConfig:
|
|
|
117
120
|
|
|
118
121
|
# Retrieve the first API name and dictionary
|
|
119
122
|
forward_item = next(iter(json_content.items()), None)
|
|
120
|
-
if not forward_item or not isinstance(forward_item[1], dict):
|
|
123
|
+
if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
|
|
121
124
|
raise ValueError(f'Invalid forward API data in json_content!')
|
|
122
125
|
|
|
123
126
|
# if propagation is backward, ensure json file contains forward and backward info
|
|
@@ -127,7 +130,7 @@ class CommonConfig:
|
|
|
127
130
|
# if propagation is backward, ensure it has valid data
|
|
128
131
|
if propagation == Const.BACKWARD:
|
|
129
132
|
backward_item = list(json_content.items())[1]
|
|
130
|
-
if not isinstance(backward_item[1], dict):
|
|
133
|
+
if not isinstance(backward_item[1], dict) or not backward_item[1]:
|
|
131
134
|
raise ValueError(f'Invalid backward API data in json_content!')
|
|
132
135
|
|
|
133
136
|
return json_content
|
|
@@ -169,7 +172,7 @@ class APIExtractor:
|
|
|
169
172
|
value = self.load_real_data_path(value, real_data_path)
|
|
170
173
|
new_data[key] = value
|
|
171
174
|
if not new_data:
|
|
172
|
-
logger.
|
|
175
|
+
logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
|
|
173
176
|
else:
|
|
174
177
|
save_json(self.output_file, new_data, indent=4)
|
|
175
178
|
logger.info(
|
|
@@ -183,6 +186,7 @@ class APIExtractor:
|
|
|
183
186
|
self.update_data_name(v, dump_data_dir)
|
|
184
187
|
return value
|
|
185
188
|
|
|
189
|
+
@recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
|
|
186
190
|
def update_data_name(self, data, dump_data_dir):
|
|
187
191
|
if isinstance(data, list):
|
|
188
192
|
for item in data:
|
|
@@ -399,7 +403,7 @@ class OperatorScriptGenerator:
|
|
|
399
403
|
def generate_kwargs_dict(self, kwargs_info, flag_device):
|
|
400
404
|
kwargs_dict_generator = ""
|
|
401
405
|
for key, value in kwargs_info.items():
|
|
402
|
-
kwargs_dict_generator += '"' + key + '"' + MonitorConst.
|
|
406
|
+
kwargs_dict_generator += '"' + key + '"' + MonitorConst.NAME_SEP
|
|
403
407
|
if flag_device:
|
|
404
408
|
kwargs_dict_generator += self.recursive_kwargs_dict(value, flag_device=True) + Const.COMMA
|
|
405
409
|
else:
|
|
@@ -467,6 +471,7 @@ def _run_operator_generate_commond(cmd_args):
|
|
|
467
471
|
fout.write(code_template.format(**internal_settings))
|
|
468
472
|
except OSError:
|
|
469
473
|
logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
|
|
474
|
+
change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
470
475
|
|
|
471
476
|
logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
|
|
472
477
|
|
|
@@ -37,9 +37,9 @@ def load_pt(pt_path, to_cpu=False):
|
|
|
37
37
|
pt_path = os.path.realpath(pt_path)
|
|
38
38
|
try:
|
|
39
39
|
if to_cpu:
|
|
40
|
-
pt = torch.load(pt_path, map_location=torch.device("cpu"))
|
|
40
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
|
|
41
41
|
else:
|
|
42
|
-
pt = torch.load(pt_path)
|
|
42
|
+
pt = torch.load(pt_path, weights_only=True)
|
|
43
43
|
except Exception as e:
|
|
44
44
|
raise RuntimeError(f"load pt file {{pt_path}} failed") from e
|
|
45
45
|
return pt
|
|
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
50
50
|
backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
|
|
51
51
|
|
|
52
52
|
input_data = load_json(input_file)
|
|
53
|
+
if "dump_data_dir" not in input_data.keys():
|
|
54
|
+
logger.error("Invalid input file, 'dump_data_dir' field is missing")
|
|
55
|
+
raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
|
|
53
56
|
if input_data.get("data") is None:
|
|
54
57
|
logger.error("Invalid input file, 'data' field is missing")
|
|
55
58
|
raise CompareException("Invalid input file, 'data' field is missing")
|
|
@@ -97,7 +100,7 @@ def run_parallel_ut(config):
|
|
|
97
100
|
processes = []
|
|
98
101
|
device_id_cycle = cycle(config.device_id)
|
|
99
102
|
if config.save_error_data_flag:
|
|
100
|
-
logger.info("UT task error
|
|
103
|
+
logger.info("UT task error data will be saved")
|
|
101
104
|
logger.info(f"Starting parallel UT with {config.num_splits} processes")
|
|
102
105
|
progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
|
|
103
106
|
|
|
@@ -221,7 +224,3 @@ def main():
|
|
|
221
224
|
args = parser.parse_args()
|
|
222
225
|
config = prepare_config(args)
|
|
223
226
|
run_parallel_ut(config)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
if __name__ == '__main__':
|
|
227
|
-
main()
|
|
@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
|
|
|
34
34
|
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
35
35
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
36
36
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
37
|
+
from msprobe.core.common.utils import check_op_str_pattern_valid
|
|
37
38
|
from msprobe.pytorch.common.log import logger
|
|
38
39
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
40
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
def check_tensor_overflow(x):
|
|
@@ -75,6 +77,7 @@ def check_data_overflow(x, device):
|
|
|
75
77
|
return torch_npu.npu.utils.npu_check_overflow(x)
|
|
76
78
|
|
|
77
79
|
|
|
80
|
+
@recursion_depth_decorator("is_bool_output")
|
|
78
81
|
def is_bool_output(x):
|
|
79
82
|
if isinstance(x, (tuple, list)):
|
|
80
83
|
if not x:
|
|
@@ -91,6 +94,7 @@ def run_overflow_check(forward_file):
|
|
|
91
94
|
dump_path = os.path.dirname(forward_file)
|
|
92
95
|
real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
|
|
93
96
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
97
|
+
check_op_str_pattern_valid(api_full_name)
|
|
94
98
|
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
95
99
|
continue
|
|
96
100
|
try:
|
|
@@ -161,6 +165,7 @@ def _run_overflow_check(parser=None):
|
|
|
161
165
|
_run_overflow_check_parser(parser)
|
|
162
166
|
args = parser.parse_args(sys.argv[1:])
|
|
163
167
|
_run_overflow_check_command(args)
|
|
168
|
+
logger.info("UT task completed.")
|
|
164
169
|
|
|
165
170
|
|
|
166
171
|
def _run_overflow_check_command(args):
|
|
@@ -175,8 +180,3 @@ def _run_overflow_check_command(args):
|
|
|
175
180
|
logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
|
|
176
181
|
raise NotImplementedError from error
|
|
177
182
|
run_overflow_check(api_info)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if __name__ == '__main__':
|
|
181
|
-
_run_overflow_check()
|
|
182
|
-
logger.info("UT task completed.")
|
|
@@ -49,7 +49,7 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
|
49
49
|
from msprobe.pytorch.common.log import logger
|
|
50
50
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
|
-
from msprobe.core.common.utils import safe_get_value, CompareException
|
|
52
|
+
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
54
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
55
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
@@ -65,6 +65,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
|
65
65
|
|
|
66
66
|
not_backward_list = ['repeat_interleave']
|
|
67
67
|
unsupported_backward_list = ['masked_select']
|
|
68
|
+
unsupported_api_list = ["to"]
|
|
68
69
|
|
|
69
70
|
|
|
70
71
|
tqdm_params = {
|
|
@@ -83,6 +84,9 @@ tqdm_params = {
|
|
|
83
84
|
}
|
|
84
85
|
|
|
85
86
|
|
|
87
|
+
seed_all()
|
|
88
|
+
|
|
89
|
+
|
|
86
90
|
def run_ut(config):
|
|
87
91
|
logger.info("start UT test")
|
|
88
92
|
if config.online_config.is_online:
|
|
@@ -93,7 +97,7 @@ def run_ut(config):
|
|
|
93
97
|
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
94
98
|
|
|
95
99
|
if config.save_error_data:
|
|
96
|
-
logger.info(f"UT task
|
|
100
|
+
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
97
101
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
98
102
|
|
|
99
103
|
if config.online_config.is_online:
|
|
@@ -117,6 +121,7 @@ def run_ut(config):
|
|
|
117
121
|
def run_api_offline(config, compare, api_name_set):
|
|
118
122
|
err_column = CompareColumn()
|
|
119
123
|
for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
|
|
124
|
+
check_op_str_pattern_valid(api_full_name)
|
|
120
125
|
if api_full_name in api_name_set:
|
|
121
126
|
continue
|
|
122
127
|
if is_unsupported_api(api_full_name):
|
|
@@ -218,6 +223,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
|
218
223
|
If api is both in black_list and white_list, black_list first.
|
|
219
224
|
return: False for exec api, True for not exec
|
|
220
225
|
"""
|
|
226
|
+
black_list.extend(unsupported_api_list)
|
|
221
227
|
if black_list and api_name in black_list:
|
|
222
228
|
return True
|
|
223
229
|
if white_list and api_name not in white_list:
|
|
@@ -317,7 +323,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
|
317
323
|
if kwargs.get("device"):
|
|
318
324
|
del kwargs["device"]
|
|
319
325
|
|
|
320
|
-
|
|
326
|
+
device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
|
|
327
|
+
device_out = exec_api(device_exec_params)
|
|
321
328
|
device_out = move2device_exec(device_out, "cpu")
|
|
322
329
|
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
323
330
|
|
|
@@ -344,6 +351,9 @@ def need_to_backward(grad_index, out):
|
|
|
344
351
|
|
|
345
352
|
def run_backward(args, grad, grad_index, out):
|
|
346
353
|
if grad_index is not None:
|
|
354
|
+
if not is_int(grad_index):
|
|
355
|
+
logger.error(f"{grad_index} dtype is not int")
|
|
356
|
+
raise TypeError(f"{grad_index} dtype is not int")
|
|
347
357
|
if grad_index >= len(out):
|
|
348
358
|
logger.error(f"Run backward error when grad_index is {grad_index}")
|
|
349
359
|
raise IndexError(f"Run backward error when grad_index is {grad_index}")
|
|
@@ -430,6 +440,7 @@ def preprocess_forward_content(forward_content):
|
|
|
430
440
|
arg_cache = {}
|
|
431
441
|
|
|
432
442
|
for key, value in forward_content.items():
|
|
443
|
+
check_op_str_pattern_valid(key)
|
|
433
444
|
base_key = key.rsplit(Const.SEP, 1)[0]
|
|
434
445
|
|
|
435
446
|
if key not in arg_cache:
|
|
@@ -469,6 +480,7 @@ def _run_ut(parser=None):
|
|
|
469
480
|
_run_ut_parser(parser)
|
|
470
481
|
args = parser.parse_args(sys.argv[1:])
|
|
471
482
|
run_ut_command(args)
|
|
483
|
+
|
|
472
484
|
|
|
473
485
|
|
|
474
486
|
def checked_online_config(online_config):
|
|
@@ -492,6 +504,7 @@ def checked_online_config(online_config):
|
|
|
492
504
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
493
505
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
494
506
|
check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
|
|
507
|
+
check_crt_valid(os.path.join(online_config.tls_path, "server.key"), True)
|
|
495
508
|
|
|
496
509
|
# host and port
|
|
497
510
|
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
@@ -561,7 +574,14 @@ def run_ut_command(args):
|
|
|
561
574
|
error_data_path = checker_config.error_data_path
|
|
562
575
|
if save_error_data:
|
|
563
576
|
if args.result_csv_path:
|
|
564
|
-
|
|
577
|
+
parts_by_dot = result_csv_path.split(Const.SEP)
|
|
578
|
+
if len(parts_by_dot) < 2 or not parts_by_dot[0]:
|
|
579
|
+
raise ValueError("result_csv_path does not contain a valid file name with an extension.")
|
|
580
|
+
file_name_part = parts_by_dot[0]
|
|
581
|
+
parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
|
|
582
|
+
if len(parts_by_underscore) < 2:
|
|
583
|
+
raise ValueError("File name part does not contain enough '_' separated segments.")
|
|
584
|
+
time_info = parts_by_underscore[-1]
|
|
565
585
|
global UT_ERROR_DATA_DIR
|
|
566
586
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
567
587
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
@@ -579,9 +599,8 @@ def run_ut_command(args):
|
|
|
579
599
|
}
|
|
580
600
|
run_ut_config = checker_config.get_run_ut_config(**config_params)
|
|
581
601
|
run_ut(run_ut_config)
|
|
602
|
+
logger.info("UT task completed.")
|
|
582
603
|
|
|
583
604
|
|
|
584
605
|
if __name__ == '__main__':
|
|
585
|
-
seed_all()
|
|
586
606
|
_run_ut()
|
|
587
|
-
logger.info("UT task completed.")
|