mindstudio-probe 8.3.2__py3-none-any.whl → 26.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.2.dist-info → mindstudio_probe-26.0.0a1.dist-info}/METADATA +26 -14
- mindstudio_probe-26.0.0a1.dist-info/RECORD +498 -0
- {mindstudio_probe-8.3.2.dist-info → mindstudio_probe-26.0.0a1.dist-info}/WHEEL +1 -1
- mindstudio_probe-26.0.0a1.dist-info/entry_points.txt +5 -0
- mindstudio_probe-26.0.0a1.dist-info/licenses/LICENSE +124 -0
- mindstudio_probe-26.0.0a1.dist-info/top_level.txt +2 -0
- msprobe/__init__.py +12 -13
- msprobe/config.json +9 -31
- msprobe/core/__init__.py +12 -11
- msprobe/core/acc_check/acc_check_cli.py +145 -0
- msprobe/core/common/const.py +97 -38
- msprobe/core/common/db_manager.py +133 -12
- msprobe/core/common/decorator.py +12 -11
- msprobe/core/common/exceptions.py +12 -11
- msprobe/core/common/file_utils.py +101 -25
- msprobe/core/common/framework_adapter.py +36 -25
- msprobe/core/common/global_lock.py +12 -11
- msprobe/core/common/inplace_op_checker.py +12 -11
- msprobe/core/common/log.py +22 -11
- msprobe/core/common/megatron_utils.py +566 -11
- msprobe/core/common/parallel_state.py +12 -11
- msprobe/core/common/runtime.py +12 -11
- msprobe/core/common/utils.py +41 -41
- msprobe/core/compare/acc_compare.py +361 -104
- msprobe/core/compare/atb_data_compare.py +422 -0
- msprobe/core/compare/auto_compare.py +134 -0
- msprobe/core/compare/check.py +14 -17
- msprobe/core/compare/compare_cli.py +72 -149
- msprobe/core/compare/config.py +12 -13
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +28 -15
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/analyzer.py +18 -18
- msprobe/core/compare/find_first/graph.py +12 -11
- msprobe/core/compare/find_first/utils.py +13 -12
- msprobe/core/compare/indicator_analysis/__init__.py +15 -0
- msprobe/core/compare/indicator_analysis/algorithm.py +363 -0
- msprobe/core/compare/indicator_analysis/api_data.py +141 -0
- msprobe/core/compare/indicator_analysis/calculator.py +181 -0
- msprobe/core/compare/indicator_analysis/utils.py +116 -0
- msprobe/core/compare/layer_mapping/__init__.py +12 -11
- msprobe/core/compare/layer_mapping/data_scope_parser.py +20 -11
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -13
- msprobe/core/compare/layer_mapping/postprocess_pass.py +13 -11
- msprobe/core/compare/merge_result/merge_result.py +12 -11
- msprobe/core/compare/merge_result/merge_result_cli.py +12 -11
- msprobe/core/compare/merge_result/utils.py +12 -11
- msprobe/core/compare/multiprocessing_compute.py +13 -14
- msprobe/core/compare/npy_compare.py +13 -11
- msprobe/core/compare/offline_data_compare.py +160 -0
- msprobe/core/compare/stats_diff_calc.py +39 -0
- msprobe/core/compare/torchair_acc_cmp.py +764 -0
- msprobe/core/compare/torchair_cmp_utils.py +338 -0
- msprobe/core/compare/utils.py +140 -49
- msprobe/core/config_check/__init__.py +12 -11
- msprobe/core/config_check/checkers/__init__.py +12 -11
- msprobe/core/config_check/checkers/base_checker.py +15 -14
- msprobe/core/config_check/checkers/dataset_checker.py +13 -12
- msprobe/core/config_check/checkers/env_args_checker.py +13 -12
- msprobe/core/config_check/checkers/hyperparameter_checker.py +16 -15
- msprobe/core/config_check/checkers/pip_checker.py +15 -15
- msprobe/core/config_check/checkers/random_checker.py +13 -12
- msprobe/core/config_check/checkers/weights_checker.py +14 -12
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +13 -17
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +13 -12
- msprobe/core/config_check/ckpt_compare/metrics.py +12 -11
- msprobe/core/config_check/config_check_cli.py +18 -17
- msprobe/core/config_check/config_checker.py +16 -14
- msprobe/core/config_check/resource/dependency.yaml +15 -12
- msprobe/core/config_check/resource/env.yaml +12 -11
- msprobe/core/config_check/utils/hyperparameter_parser.py +12 -11
- msprobe/core/config_check/utils/utils.py +12 -11
- msprobe/core/{data_dump → dump/api_dump}/api_registry.py +12 -11
- msprobe/core/{common_config.py → dump/common_config.py} +13 -24
- msprobe/core/dump/data_dump/data_collector.py +257 -0
- msprobe/core/{data_dump → dump/data_dump}/data_processor/base.py +45 -36
- msprobe/core/{data_dump → dump/data_dump}/data_processor/factory.py +33 -25
- msprobe/core/{data_dump → dump/data_dump}/data_processor/mindspore_processor.py +37 -113
- msprobe/core/{data_dump → dump/data_dump}/data_processor/pytorch_processor.py +364 -131
- msprobe/core/{data_dump → dump/data_dump}/json_writer.py +24 -31
- msprobe/core/{data_dump → dump/data_dump}/scope.py +12 -13
- msprobe/core/{debugger → dump/debugger}/precision_debugger.py +15 -23
- msprobe/core/dump/dump2db/db_utils.py +215 -0
- msprobe/core/dump/dump2db/dump2db.py +409 -0
- msprobe/core/{hook_manager.py → dump/hook_manager.py} +38 -87
- msprobe/core/dump/kernel_dump/kernel_config.py +34 -0
- msprobe/core/{service.py → dump/service.py} +43 -27
- msprobe/core/install_deps/install_deps.py +51 -0
- msprobe/core/monitor/anomaly_processor.py +13 -11
- msprobe/core/monitor/csv2db.py +73 -93
- msprobe/core/monitor/db_utils.py +140 -205
- msprobe/core/monitor/utils.py +18 -17
- msprobe/core/monitor_v2/__init__.py +20 -0
- msprobe/core/monitor_v2/base.py +83 -0
- msprobe/core/monitor_v2/cc.py +287 -0
- msprobe/core/monitor_v2/factory.py +81 -0
- msprobe/core/monitor_v2/module.py +201 -0
- msprobe/core/monitor_v2/optimizer.py +245 -0
- msprobe/core/monitor_v2/param.py +154 -0
- msprobe/core/monitor_v2/trainer.py +326 -0
- msprobe/core/monitor_v2/utils.py +122 -0
- msprobe/core/monitor_v2/weight_grad.py +419 -0
- msprobe/core/monitor_v2/writer.py +162 -0
- msprobe/core/overflow_check/abnormal_scene.py +12 -11
- msprobe/core/overflow_check/api_info.py +12 -11
- msprobe/core/overflow_check/checker.py +12 -11
- msprobe/core/overflow_check/filter.py +13 -11
- msprobe/core/overflow_check/level.py +12 -11
- msprobe/core/overflow_check/utils.py +12 -11
- msprobe/core/single_save/single_comparator.py +12 -11
- msprobe/core/single_save/single_saver.py +12 -11
- msprobe/infer/__init__.py +16 -0
- msprobe/infer/offline/__init__.py +16 -0
- msprobe/infer/offline/compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/args_adapter.py +46 -0
- msprobe/infer/offline/compare/msquickcmp/atc/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/atc/atc_utils.py +98 -0
- msprobe/infer/offline/compare/msquickcmp/cmp_process.py +328 -0
- msprobe/infer/offline/compare/msquickcmp/common/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/common/args_check.py +112 -0
- msprobe/infer/offline/compare/msquickcmp/common/convert.py +74 -0
- msprobe/infer/offline/compare/msquickcmp/common/dump_data.py +121 -0
- msprobe/infer/offline/compare/msquickcmp/common/dynamic_argument_bean.py +39 -0
- msprobe/infer/offline/compare/msquickcmp/common/utils.py +669 -0
- msprobe/infer/offline/compare/msquickcmp/config.ini +6 -0
- msprobe/infer/offline/compare/msquickcmp/dump/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/dump/args_adapter.py +50 -0
- msprobe/infer/offline/compare/msquickcmp/dump/dump_process.py +91 -0
- msprobe/infer/offline/compare/msquickcmp/install_aclruntime_aisbench.sh +180 -0
- msprobe/infer/offline/compare/msquickcmp/main.py +199 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/net_compare.py +277 -0
- msprobe/infer/offline/compare/msquickcmp/npu/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/npu/npu_dump_data.py +558 -0
- msprobe/infer/offline/compare/msquickcmp/npu/om_parser.py +416 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/onnx_dump_data.py +374 -0
- msprobe/infer/utils/__init__.py +15 -0
- msprobe/infer/utils/acc_cmp.py +94 -0
- msprobe/infer/utils/check/__init__.py +37 -0
- msprobe/infer/utils/check/args_checker.py +35 -0
- msprobe/infer/utils/check/checker.py +227 -0
- msprobe/infer/utils/check/dict_checker.py +78 -0
- msprobe/infer/utils/check/func_wrapper.py +96 -0
- msprobe/infer/utils/check/list_checker.py +56 -0
- msprobe/infer/utils/check/number_checker.py +64 -0
- msprobe/infer/utils/check/obj_checker.py +41 -0
- msprobe/infer/utils/check/path_checker.py +249 -0
- msprobe/infer/utils/check/rule.py +126 -0
- msprobe/infer/utils/check/string_checker.py +66 -0
- msprobe/infer/utils/cmp_algorithm.py +261 -0
- msprobe/infer/utils/constants.py +112 -0
- msprobe/infer/utils/file_open_check.py +337 -0
- msprobe/infer/utils/util.py +177 -0
- msprobe/mindspore/__init__.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_info.py +12 -11
- msprobe/mindspore/api_accuracy_checker/api_runner.py +12 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +12 -11
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +12 -11
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +15 -14
- msprobe/mindspore/api_accuracy_checker/compute_element.py +12 -11
- msprobe/mindspore/api_accuracy_checker/data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/main.py +12 -11
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +14 -12
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +12 -11
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +12 -11
- msprobe/mindspore/api_accuracy_checker/utils.py +12 -11
- msprobe/mindspore/common/const.py +15 -74
- msprobe/mindspore/common/log.py +12 -11
- msprobe/mindspore/common/utils.py +30 -15
- msprobe/mindspore/compare/common_dir_compare.py +21 -23
- msprobe/mindspore/compare/distributed_compare.py +18 -16
- msprobe/mindspore/compare/ms_compare.py +14 -14
- msprobe/mindspore/compare/ms_graph_compare.py +26 -20
- msprobe/mindspore/compare/utils.py +14 -12
- msprobe/mindspore/{cell_processor.py → dump/cell_processor.py} +15 -14
- msprobe/mindspore/{debugger → dump/debugger}/debugger_config.py +12 -30
- msprobe/mindspore/{debugger → dump/debugger}/precision_debugger.py +43 -45
- msprobe/mindspore/dump/{cell_dump_process.py → dump_processor/cell_dump_process.py} +31 -17
- msprobe/mindspore/dump/{cell_dump_with_insert_gradient.py → dump_processor/cell_dump_with_insert_gradient.py} +18 -14
- msprobe/mindspore/dump/{dump_tool_factory.py → dump_processor/dump_tool_factory.py} +16 -15
- msprobe/mindspore/dump/{graph_mode_cell_dump.py → dump_processor/graph_mode_cell_dump.py} +16 -15
- msprobe/mindspore/dump/{graph_tensor_dump.py → dump_processor/graph_tensor_dump.py} +134 -133
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/api_register.py +15 -14
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/hook_cell.py +12 -11
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/ms_hook_manager.py +47 -20
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/primitive_hooks.py +14 -13
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/support_wrap_ops.yaml +13 -11
- msprobe/mindspore/dump/{jit_dump.py → dump_processor/jit_dump.py} +14 -13
- msprobe/mindspore/dump/{kernel_graph_dump.py → dump_processor/kernel_graph_dump.py} +13 -12
- msprobe/mindspore/dump/{kernel_kbyk_dump.py → dump_processor/kernel_kbyk_dump.py} +13 -12
- msprobe/mindspore/{exception_dump → dump/exception_dump}/exception_dump_tool_factory.py +14 -13
- msprobe/mindspore/{exception_dump → dump/exception_dump}/kernel_graph_exception_dump.py +13 -12
- msprobe/mindspore/{mindspore_service.py → dump/mindspore_service.py} +18 -17
- msprobe/mindspore/dump/mindtorch/__init__.py +19 -0
- msprobe/mindspore/dump/ms_config.py +105 -0
- msprobe/mindspore/{overflow_check → dump/overflow_check}/kernel_graph_overflow_check.py +13 -12
- msprobe/mindspore/{overflow_check → dump/overflow_check}/overflow_check_tool_factory.py +14 -13
- msprobe/mindspore/dump/task_handler_factory.py +43 -0
- msprobe/mindspore/monitor/common_func.py +12 -11
- msprobe/mindspore/monitor/data_writers.py +12 -11
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +93 -39
- msprobe/mindspore/monitor/features.py +12 -11
- msprobe/mindspore/monitor/module_hook.py +19 -22
- msprobe/mindspore/monitor/optimizer_collect.py +29 -25
- msprobe/mindspore/monitor/utils.py +13 -11
- msprobe/msaccucmp/advisor/__init__.py +16 -0
- msprobe/msaccucmp/advisor/advisor_const.py +65 -0
- msprobe/msaccucmp/advisor/advisor_result.py +73 -0
- msprobe/msaccucmp/advisor/compare_advisor.py +99 -0
- msprobe/msaccucmp/advisor/input_advisor.py +66 -0
- msprobe/msaccucmp/advisor/node_advisor.py +68 -0
- msprobe/msaccucmp/advisor/overflow_advisor.py +58 -0
- msprobe/msaccucmp/algorithm_manager/__init__.py +16 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_manager.py +464 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_parameter.py +42 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_AccumulatedRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_CosineSimilarity.py +58 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_KullbackLeiblerDivergence.py +84 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RelativeEuclideanDistance.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RootMeanSquareError.py +40 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_StandardDeviation.py +47 -0
- msprobe/msaccucmp/cmp_utils/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/common.py +113 -0
- msprobe/msaccucmp/cmp_utils/constant/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/constant/compare_error.py +81 -0
- msprobe/msaccucmp/cmp_utils/constant/const_manager.py +530 -0
- msprobe/msaccucmp/cmp_utils/file_utils.py +497 -0
- msprobe/msaccucmp/cmp_utils/log.py +257 -0
- msprobe/msaccucmp/cmp_utils/multi_process/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/multi_process/multi_convert_process.py +140 -0
- msprobe/msaccucmp/cmp_utils/multi_process/progress.py +78 -0
- msprobe/msaccucmp/cmp_utils/path_check.py +274 -0
- msprobe/msaccucmp/cmp_utils/reg_manager.py +98 -0
- msprobe/msaccucmp/cmp_utils/tlv_parse.py +279 -0
- msprobe/msaccucmp/cmp_utils/utils.py +356 -0
- msprobe/msaccucmp/cmp_utils/utils_type.py +63 -0
- msprobe/msaccucmp/compare_vector.py +48 -0
- msprobe/msaccucmp/conversion/__init__.py +16 -0
- msprobe/msaccucmp/conversion/data_conversion.py +277 -0
- msprobe/msaccucmp/conversion/dtype_conversion.py +99 -0
- msprobe/msaccucmp/conversion/shape_format_conversion.py +477 -0
- msprobe/msaccucmp/conversion/tensor_conversion.py +369 -0
- msprobe/msaccucmp/dump_data_conversion.py +46 -0
- msprobe/msaccucmp/dump_parse/__init__.py +16 -0
- msprobe/msaccucmp/dump_parse/big_dump_data.py +317 -0
- msprobe/msaccucmp/dump_parse/dump.py +423 -0
- msprobe/msaccucmp/dump_parse/dump_data_object.py +322 -0
- msprobe/msaccucmp/dump_parse/dump_data_parser.py +436 -0
- msprobe/msaccucmp/dump_parse/dump_utils.py +246 -0
- msprobe/msaccucmp/dump_parse/ffts_parser.py +137 -0
- msprobe/msaccucmp/dump_parse/mapping.py +62 -0
- msprobe/msaccucmp/dump_parse/nano_dump_data.py +392 -0
- msprobe/msaccucmp/dump_parse/proto_dump_data.py +308 -0
- msprobe/msaccucmp/dump_parser.py +90 -0
- msprobe/msaccucmp/format_manager/__init__.py +16 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NCHW.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_ND.py +52 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NHWC.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_HWCN.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_NCHW.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_FRACTAL_Z.py +89 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_HWCN.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NCHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NHWC.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_NCDHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_ND.py +44 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_HWCN.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/format_manager.py +307 -0
- msprobe/msaccucmp/inplace_layer_process.py +186 -0
- msprobe/msaccucmp/msaccucmp.py +532 -0
- msprobe/msaccucmp/mscmp_advisor.py +128 -0
- msprobe/msaccucmp/overflow/__init__.py +16 -0
- msprobe/msaccucmp/overflow/overflow_analyse.py +305 -0
- msprobe/msaccucmp/overflow/overflow_detection.py +143 -0
- msprobe/msaccucmp/pytorch_cmp/__init__.py +16 -0
- msprobe/msaccucmp/pytorch_cmp/compare_pytorch.py +389 -0
- msprobe/msaccucmp/pytorch_cmp/hdf5_parser.py +377 -0
- msprobe/msaccucmp/pytorch_cmp/pytorch_dump_data.py +461 -0
- msprobe/msaccucmp/shape_conversion.py +41 -0
- msprobe/msaccucmp/vector_cmp/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/batch_compare.py +197 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/compare_detail.py +245 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail.py +182 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail_writer.py +580 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_fusion_op.py +588 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_npu_vs_npu.py +339 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_result.py +326 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_rule.py +156 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_op.py +204 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_rule_parser.py +635 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/quant_filter.py +187 -0
- msprobe/msaccucmp/vector_cmp/range_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_manager.py +100 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_mode.py +94 -0
- msprobe/msaccucmp/vector_cmp/range_manager/select_mode.py +86 -0
- msprobe/msaccucmp/vector_cmp/vector_comparison.py +535 -0
- msprobe/msprobe.py +101 -130
- msprobe/overflow_check/__init__.py +15 -0
- msprobe/{nan_analyze → overflow_check}/analyzer.py +38 -27
- msprobe/{nan_analyze → overflow_check}/graph.py +30 -27
- msprobe/{nan_analyze → overflow_check}/utils.py +15 -14
- msprobe/pytorch/__init__.py +20 -14
- msprobe/pytorch/aclgraph_dump/__init__.py +45 -0
- msprobe/pytorch/aclgraph_dump/_meta.py +26 -0
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut.py → acc_check/acc_check.py} +50 -45
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut_utils.py → acc_check/acc_check_utils.py} +201 -30
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/data_generate.py +56 -16
- msprobe/pytorch/api_accuracy_checker/{run_ut/multi_run_ut.py → acc_check/multi_acc_check.py} +32 -47
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/run_overflow_check.py +19 -18
- msprobe/pytorch/api_accuracy_checker/common/config.py +22 -20
- msprobe/pytorch/api_accuracy_checker/common/utils.py +72 -13
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -11
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +23 -14
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +45 -32
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +12 -11
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +14 -12
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +14 -12
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +21 -19
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +14 -13
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +60 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +27 -16
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +13 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +39 -18
- msprobe/pytorch/bench_functions/__init__.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam_w.py +12 -11
- msprobe/pytorch/bench_functions/confusion_transpose.py +12 -11
- msprobe/pytorch/bench_functions/fast_gelu.py +12 -11
- msprobe/pytorch/bench_functions/group_norm_silu.py +12 -11
- msprobe/pytorch/bench_functions/layer_norm_eval.py +12 -11
- msprobe/pytorch/bench_functions/linear.py +12 -11
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -11
- msprobe/pytorch/bench_functions/mish.py +12 -11
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +12 -11
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +12 -11
- msprobe/pytorch/bench_functions/rms_norm.py +12 -11
- msprobe/pytorch/bench_functions/rotary_mul.py +12 -11
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +12 -11
- msprobe/pytorch/bench_functions/sort_v2.py +12 -11
- msprobe/pytorch/bench_functions/swiglu.py +12 -11
- msprobe/pytorch/common/__init__.py +12 -11
- msprobe/pytorch/common/log.py +12 -11
- msprobe/pytorch/common/parse_json.py +12 -11
- msprobe/pytorch/common/utils.py +52 -19
- msprobe/pytorch/compare/distributed_compare.py +13 -13
- msprobe/pytorch/compare/match.py +12 -11
- msprobe/pytorch/compare/pt_compare.py +14 -20
- msprobe/pytorch/compare/pt_diff_analyze.py +12 -11
- msprobe/pytorch/compare/utils.py +12 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/api_register.py +18 -16
- msprobe/pytorch/{hook_module → dump/api_dump}/hook_module.py +14 -13
- msprobe/pytorch/{hook_module → dump/api_dump}/pt_hook_manager.py +68 -23
- msprobe/pytorch/{hook_module → dump/api_dump}/register_optimizer_hook.py +13 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/script_wrapper.py +17 -14
- msprobe/pytorch/{hook_module → dump/api_dump}/utils.py +12 -11
- msprobe/pytorch/{debugger → dump/debugger}/debugger_config.py +23 -38
- msprobe/pytorch/dump/debugger/precision_debugger.py +130 -0
- msprobe/pytorch/{function_factory.py → dump/function_factory.py} +12 -11
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +17 -13
- msprobe/pytorch/dump/module_dump/module_dump.py +16 -15
- msprobe/pytorch/dump/module_dump/{module_processer.py → module_processor.py} +54 -42
- msprobe/pytorch/dump/pt_config.py +128 -0
- msprobe/pytorch/{pytorch_service.py → dump/pytorch_service.py} +22 -21
- msprobe/pytorch/monitor/csv2tb.py +13 -11
- msprobe/pytorch/monitor/data_writers.py +13 -11
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +13 -11
- msprobe/pytorch/monitor/features.py +12 -11
- msprobe/pytorch/monitor/module_hook.py +67 -59
- msprobe/pytorch/monitor/module_metric.py +13 -11
- msprobe/pytorch/monitor/optimizer_collect.py +37 -35
- msprobe/pytorch/monitor/utils.py +13 -11
- msprobe/pytorch/monitor/visualizer.py +12 -11
- msprobe/pytorch/torchair_dump/__init__.py +17 -0
- msprobe/pytorch/torchair_dump/torchair_dump.py +114 -0
- msprobe/scripts/atb/config_example.json +10 -0
- msprobe/scripts/atb/load_atb_probe.sh +101 -0
- msprobe/scripts/atb/unload_atb_probe.sh +27 -0
- msprobe/scripts/build_msaccucmp.sh +186 -0
- msprobe/scripts/conf/help.info +6 -0
- msprobe/scripts/conf/version.info +3 -0
- msprobe/scripts/run_script/common.sh +538 -0
- msprobe/scripts/run_script/main_msaccucmp.sh +232 -0
- msprobe/visualization/__init__.py +12 -11
- msprobe/visualization/builder/__init__.py +12 -11
- msprobe/visualization/builder/graph_builder.py +45 -30
- msprobe/visualization/builder/graph_merger.py +53 -32
- msprobe/visualization/builder/msprobe_adapter.py +34 -44
- msprobe/visualization/compare/__init__.py +12 -11
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +28 -113
- msprobe/visualization/db_utils.py +133 -22
- msprobe/visualization/graph/__init__.py +12 -11
- msprobe/visualization/graph/base_node.py +15 -27
- msprobe/visualization/graph/distributed_analyzer.py +97 -40
- msprobe/visualization/graph/graph.py +14 -16
- msprobe/visualization/graph/node_colors.py +34 -31
- msprobe/visualization/graph/node_op.py +12 -11
- msprobe/visualization/graph_service.py +580 -205
- msprobe/visualization/utils.py +278 -31
- tb_graph_ascend/secure_build.py +175 -0
- tb_graph_ascend/server/__init__.py +15 -0
- tb_graph_ascend/server/app/__init__.py +15 -0
- tb_graph_ascend/server/app/model/__init__.py +15 -0
- tb_graph_ascend/server/app/model/hierarchy.py +348 -0
- tb_graph_ascend/server/app/model/layout_hierarchy_model.py +69 -0
- tb_graph_ascend/server/app/model/match_nodes_model.py +573 -0
- tb_graph_ascend/server/app/repositories/__init__.py +15 -0
- tb_graph_ascend/server/app/repositories/graph_repo_base.py +32 -0
- tb_graph_ascend/server/app/repositories/graph_repo_db.py +879 -0
- tb_graph_ascend/server/app/repositories/graph_repo_vis.py +83 -0
- tb_graph_ascend/server/app/service/__init__.py +18 -0
- tb_graph_ascend/server/app/service/graph_service_base.py +158 -0
- tb_graph_ascend/server/app/service/graph_service_db.py +438 -0
- tb_graph_ascend/server/app/service/graph_service_factory.py +54 -0
- tb_graph_ascend/server/app/service/graph_service_vis.py +480 -0
- tb_graph_ascend/server/app/utils/__init__.py +15 -0
- tb_graph_ascend/server/app/utils/constant.py +80 -0
- tb_graph_ascend/server/app/utils/file_check_wrapper.py +46 -0
- tb_graph_ascend/server/app/utils/global_state.py +95 -0
- tb_graph_ascend/server/app/utils/graph_utils.py +661 -0
- tb_graph_ascend/server/app/utils/i18n.py +153 -0
- tb_graph_ascend/server/app/utils/request_method.py +46 -0
- tb_graph_ascend/server/app/views/__init__.py +15 -0
- tb_graph_ascend/server/app/views/graph_views.py +304 -0
- tb_graph_ascend/server/plugin.py +108 -0
- tb_graph_ascend/server/static/index.html +9250 -0
- tb_graph_ascend/server/static/index.js +21 -0
- tb_graph_ascend/setup.py +57 -0
- mindstudio_probe-8.3.2.dist-info/LICENSE +0 -201
- mindstudio_probe-8.3.2.dist-info/RECORD +0 -491
- mindstudio_probe-8.3.2.dist-info/entry_points.txt +0 -2
- mindstudio_probe-8.3.2.dist-info/top_level.txt +0 -1
- msprobe/CMakeLists.txt +0 -5
- msprobe/README.md +0 -203
- msprobe/core/advisor/advisor.py +0 -129
- msprobe/core/advisor/advisor_const.py +0 -58
- msprobe/core/advisor/advisor_result.py +0 -58
- msprobe/core/compare/find_first/data_processor.py +0 -35
- msprobe/core/compare/highlight.py +0 -390
- msprobe/core/data_dump/data_collector.py +0 -356
- msprobe/core/grad_probe/constant.py +0 -90
- msprobe/core/grad_probe/grad_compare.py +0 -187
- msprobe/core/grad_probe/utils.py +0 -105
- msprobe/core/kernel_dump/kernel_config.py +0 -33
- msprobe/docs/01.installation.md +0 -250
- msprobe/docs/02.config_introduction.md +0 -221
- msprobe/docs/03.config_examples.md +0 -281
- msprobe/docs/04.kernel_dump_PyTorch.md +0 -73
- msprobe/docs/05.data_dump_PyTorch.md +0 -518
- msprobe/docs/06.data_dump_MindSpore.md +0 -618
- msprobe/docs/07.accuracy_checker_PyTorch.md +0 -310
- msprobe/docs/09.accuracy_checker_MindSpore.md +0 -120
- msprobe/docs/10.accuracy_compare_PyTorch.md +0 -637
- msprobe/docs/11.accuracy_compare_MindSpore.md +0 -769
- msprobe/docs/12.overflow_check_PyTorch.md +0 -82
- msprobe/docs/13.overflow_check_MindSpore.md +0 -33
- msprobe/docs/14.data_parse_PyTorch.md +0 -282
- msprobe/docs/15.free_benchmarking_PyTorch.md +0 -169
- msprobe/docs/16.free_benchmarking_MindSpore.md +0 -159
- msprobe/docs/17.grad_probe.md +0 -205
- msprobe/docs/18.online_dispatch.md +0 -89
- msprobe/docs/19.monitor.md +0 -753
- msprobe/docs/20.monitor_performance_baseline.md +0 -52
- msprobe/docs/21.visualization_PyTorch.md +0 -519
- msprobe/docs/22.visualization_MindSpore.md +0 -515
- msprobe/docs/23.generate_operator_PyTorch.md +0 -107
- msprobe/docs/24.code_mapping_Mindspore.md +0 -29
- msprobe/docs/25.tool_function_introduction.md +0 -29
- msprobe/docs/26.data_dump_PyTorch_baseline.md +0 -48
- msprobe/docs/27.dump_json_instruction.md +0 -795
- msprobe/docs/28.debugger_save_instruction.md +0 -288
- msprobe/docs/28.kernel_dump_MindSpore.md +0 -69
- msprobe/docs/29.data_dump_MSAdapter.md +0 -235
- msprobe/docs/30.overflow_check_MSAdapter.md +0 -31
- msprobe/docs/31.config_check.md +0 -107
- msprobe/docs/32.ckpt_compare.md +0 -69
- msprobe/docs/33.generate_operator_MindSpore.md +0 -181
- msprobe/docs/34.RL_collect.md +0 -101
- msprobe/docs/35.nan_analyze.md +0 -73
- msprobe/docs/36.calculation_result_change.md +0 -75
- msprobe/docs/FAQ.md +0 -232
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +0 -146
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +0 -14
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +0 -33
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +0 -217
- msprobe/docs/img/BLOOM-7B_1.png +0 -0
- msprobe/docs/img/BLOOM-7B_2.png +0 -0
- msprobe/docs/img/BLOOM-7B_3.png +0 -0
- msprobe/docs/img/BLOOM-7B_4.png +0 -0
- msprobe/docs/img/GPT-3_1.png +0 -0
- msprobe/docs/img/GPT-3_2.png +0 -0
- msprobe/docs/img/GPT-3_3.png +0 -0
- msprobe/docs/img/GPT-3_4.png +0 -0
- msprobe/docs/img/GPT-3_5.png +0 -0
- msprobe/docs/img/GPT-3_6.png +0 -0
- msprobe/docs/img/GPT-3_7.png +0 -0
- msprobe/docs/img/GPT-3_8.png +0 -0
- msprobe/docs/img/YOLOV5S_1.png +0 -0
- msprobe/docs/img/YOLOV5S_2.png +0 -0
- msprobe/docs/img/accuracy_checking_details.png +0 -0
- msprobe/docs/img/accuracy_checking_result.png +0 -0
- msprobe/docs/img/api_precision_compare_details.png +0 -0
- msprobe/docs/img/api_precision_compare_result.png +0 -0
- msprobe/docs/img/auto_analyze_log.png +0 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/compare_result_pkl.png +0 -0
- msprobe/docs/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/docs/img/cpu_info.png +0 -0
- msprobe/docs/img/free_benchmark.png +0 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/module_compare.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +0 -132
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +0 -59
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +0 -80
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +0 -330
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +0 -460
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +0 -2081
- msprobe/mindspore/code_mapping/bind.py +0 -283
- msprobe/mindspore/code_mapping/cmd_parser.py +0 -40
- msprobe/mindspore/code_mapping/graph.py +0 -49
- msprobe/mindspore/code_mapping/graph_parser.py +0 -211
- msprobe/mindspore/code_mapping/main.py +0 -24
- msprobe/mindspore/code_mapping/processor.py +0 -34
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +0 -111
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -52
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +0 -257
- msprobe/mindspore/free_benchmark/common/config.py +0 -27
- msprobe/mindspore/free_benchmark/common/handler_params.py +0 -31
- msprobe/mindspore/free_benchmark/common/utils.py +0 -100
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -638
- msprobe/mindspore/free_benchmark/handler/base_handler.py +0 -105
- msprobe/mindspore/free_benchmark/handler/check_handler.py +0 -55
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +0 -51
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +0 -36
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +0 -82
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +0 -45
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +0 -78
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +0 -77
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +0 -56
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +0 -27
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +0 -46
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +0 -51
- msprobe/mindspore/grad_probe/global_context.py +0 -127
- msprobe/mindspore/grad_probe/grad_analyzer.py +0 -260
- msprobe/mindspore/grad_probe/grad_monitor.py +0 -42
- msprobe/mindspore/grad_probe/grad_stat_csv.py +0 -161
- msprobe/mindspore/grad_probe/hook.py +0 -115
- msprobe/mindspore/grad_probe/utils.py +0 -43
- msprobe/mindspore/mindtorch/__init__.py +0 -18
- msprobe/mindspore/ms_config.py +0 -153
- msprobe/mindspore/task_handler_factory.py +0 -44
- msprobe/nan_analyze/__init__.py +0 -14
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +0 -9
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +0 -480
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +0 -567
- msprobe/pytorch/debugger/precision_debugger.py +0 -181
- msprobe/pytorch/free_benchmark/__init__.py +0 -23
- msprobe/pytorch/free_benchmark/common/constant.py +0 -85
- msprobe/pytorch/free_benchmark/common/counter.py +0 -87
- msprobe/pytorch/free_benchmark/common/enums.py +0 -80
- msprobe/pytorch/free_benchmark/common/params.py +0 -152
- msprobe/pytorch/free_benchmark/common/utils.py +0 -143
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -215
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +0 -121
- msprobe/pytorch/free_benchmark/main.py +0 -123
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +0 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +0 -56
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +0 -107
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +0 -121
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +0 -89
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +0 -87
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +0 -43
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +0 -60
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +0 -34
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +0 -252
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +0 -54
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +0 -40
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -45
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -181
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +0 -108
- msprobe/pytorch/grad_probe/grad_stat_csv.py +0 -160
- msprobe/pytorch/hook_module/__init__.py +0 -16
- msprobe/pytorch/hook_module/wrap_aten.py +0 -111
- msprobe/pytorch/online_dispatch/__init__.py +0 -19
- msprobe/pytorch/online_dispatch/compare.py +0 -224
- msprobe/pytorch/online_dispatch/dispatch.py +0 -332
- msprobe/pytorch/online_dispatch/dump_compare.py +0 -179
- msprobe/pytorch/online_dispatch/single_compare.py +0 -412
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +0 -58
- msprobe/pytorch/online_dispatch/utils.py +0 -158
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +0 -31
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +0 -253
- msprobe/pytorch/parse_tool/lib/config.py +0 -50
- msprobe/pytorch/parse_tool/lib/file_desc.py +0 -45
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +0 -97
- msprobe/pytorch/parse_tool/lib/parse_exception.py +0 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +0 -161
- msprobe/pytorch/parse_tool/lib/utils.py +0 -299
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -85
- msprobe/pytorch/pt_config.py +0 -299
- /msprobe/core/{grad_probe → dump}/__init__.py +0 -0
- /msprobe/{mindspore/code_mapping → core/dump/api_dump}/__init__.py +0 -0
- /msprobe/{mindspore/debugger → core/dump/data_dump}/__init__.py +0 -0
- /msprobe/{mindspore/exception_dump → core/dump/data_dump/data_processor}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark → core/dump/debugger}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark/common → core/dump/kernel_dump}/__init__.py +0 -0
- /msprobe/mindspore/{free_benchmark/handler → dump/debugger}/__init__.py +0 -0
- /msprobe/mindspore/{grad_probe → dump/dump_processor}/__init__.py +0 -0
- /msprobe/mindspore/{overflow_check → dump/exception_dump}/__init__.py +0 -0
- /msprobe/mindspore/{mindtorch → dump/mindtorch}/mindtorch_adaptor.py +0 -0
- /msprobe/{pytorch/api_accuracy_checker/run_ut → mindspore/dump/overflow_check}/__init__.py +0 -0
- /msprobe/{pytorch/debugger → mindspore/monitor}/__init__.py +0 -0
- /msprobe/{pytorch/free_benchmark/common → msaccucmp}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/.keep +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers → api_accuracy_checker/acc_check}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/torch_ut_setting.json +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers/npu → dump/api_dump}/__init__.py +0 -0
- /msprobe/pytorch/{hook_module → dump/api_dump}/support_wrap_ops.yaml +0 -0
- /msprobe/pytorch/{free_benchmark/result_handlers → dump/debugger}/__init__.py +0 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# This file is part of the MindStudio project.
|
|
3
|
+
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
|
|
4
|
+
#
|
|
5
|
+
# MindStudio is licensed under Mulan PSL v2.
|
|
6
|
+
# You can use this software according to the terms and conditions of the Mulan PSL v2.
|
|
7
|
+
# You may obtain a copy of Mulan PSL v2 at:
|
|
8
|
+
#
|
|
9
|
+
# http://license.coscl.org.cn/MulanPSL2
|
|
10
|
+
#
|
|
11
|
+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
|
12
|
+
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
13
|
+
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
14
|
+
# See the Mulan PSL v2 for more details.
|
|
15
|
+
# -------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch_npu
|
|
20
|
+
|
|
21
|
+
# Import the C++ extension to register TORCH_LIBRARY implementations.
|
|
22
|
+
try:
|
|
23
|
+
from msprobe.lib import aclgraph_dump_ext # noqa: F401
|
|
24
|
+
except Exception as exc:
|
|
25
|
+
raise RuntimeError(f"Failed to import msprobe.lib.aclgraph_dump_ext: {exc}")
|
|
26
|
+
|
|
27
|
+
# Register Python fake implementation for meta tensors.
|
|
28
|
+
from ._meta import _register_meta # noqa: E402
|
|
29
|
+
_register_meta()
|
|
30
|
+
|
|
31
|
+
from torch.fx.node import has_side_effect
|
|
32
|
+
has_side_effect(torch.ops.my_ns.acl_save.default)
|
|
33
|
+
|
|
34
|
+
def acl_save(x: torch.Tensor, path: str) -> torch.Tensor:
|
|
35
|
+
"""
|
|
36
|
+
acl_save(tensor, path) -> tensor
|
|
37
|
+
|
|
38
|
+
Copy tensor to CPU and save to a .pt file.
|
|
39
|
+
The file name is generated as {base}_{seq}.pt in the same directory.
|
|
40
|
+
For NPU input, the save runs on the current NPU stream; synchronize if needed.
|
|
41
|
+
"""
|
|
42
|
+
return torch.ops.my_ns.acl_save(x, path)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
__all__ = ["acl_save"]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# This file is part of the MindStudio project.
|
|
3
|
+
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
|
|
4
|
+
#
|
|
5
|
+
# MindStudio is licensed under Mulan PSL v2.
|
|
6
|
+
# You can use this software according to the terms and conditions of the Mulan PSL v2.
|
|
7
|
+
# You may obtain a copy of Mulan PSL v2 at:
|
|
8
|
+
#
|
|
9
|
+
# http://license.coscl.org.cn/MulanPSL2
|
|
10
|
+
#
|
|
11
|
+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
|
12
|
+
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
13
|
+
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
14
|
+
# See the Mulan PSL v2 for more details.
|
|
15
|
+
# -------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
def _register_meta():
|
|
21
|
+
try:
|
|
22
|
+
@torch.library.register_fake("my_ns.acl_save")
|
|
23
|
+
def _fake_acl_save(x: torch.Tensor, path: str):
|
|
24
|
+
return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device="meta")
|
|
25
|
+
except Exception:
|
|
26
|
+
pass
|
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
# -*- coding: utf-8 -*-
|
|
3
|
-
#
|
|
4
|
-
#
|
|
3
|
+
# -------------------------------------------------------------------------
|
|
4
|
+
# This file is part of the MindStudio project.
|
|
5
|
+
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
|
|
5
6
|
#
|
|
6
|
-
#
|
|
7
|
-
#
|
|
8
|
-
# You may obtain a copy of
|
|
7
|
+
# MindStudio is licensed under Mulan PSL v2.
|
|
8
|
+
# You can use this software according to the terms and conditions of the Mulan PSL v2.
|
|
9
|
+
# You may obtain a copy of Mulan PSL v2 at:
|
|
9
10
|
#
|
|
10
|
-
#
|
|
11
|
+
# http://license.coscl.org.cn/MulanPSL2
|
|
11
12
|
#
|
|
12
|
-
#
|
|
13
|
-
#
|
|
14
|
-
#
|
|
15
|
-
# See the
|
|
16
|
-
#
|
|
13
|
+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
|
14
|
+
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
15
|
+
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
16
|
+
# See the Mulan PSL v2 for more details.
|
|
17
|
+
# -------------------------------------------------------------------------
|
|
17
18
|
|
|
18
19
|
import argparse
|
|
19
20
|
import os
|
|
@@ -35,9 +36,9 @@ else:
|
|
|
35
36
|
import torch
|
|
36
37
|
from tqdm import tqdm
|
|
37
38
|
|
|
38
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
39
|
+
from msprobe.pytorch.api_accuracy_checker.acc_check.acc_check_utils import BackwardMessage, UtDataInfo, \
|
|
39
40
|
get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
|
|
40
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
41
|
+
from msprobe.pytorch.api_accuracy_checker.acc_check.data_generate import gen_api_params, gen_args
|
|
41
42
|
from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
|
|
42
43
|
initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
|
|
43
44
|
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
@@ -47,11 +48,12 @@ from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
|
47
48
|
from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
48
49
|
create_directory, get_json_contents, read_csv, check_file_or_directory_path
|
|
49
50
|
from msprobe.pytorch.common.log import logger
|
|
50
|
-
from msprobe.pytorch.pt_config import parse_json_config
|
|
51
|
+
from msprobe.pytorch.dump.pt_config import parse_json_config
|
|
51
52
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
53
|
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
54
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
55
|
+
from msprobe.pytorch.api_accuracy_checker.acc_check.acc_check_utils import generate_cpu_params,\
|
|
56
|
+
generate_device_params, \
|
|
55
57
|
ExecParams
|
|
56
58
|
|
|
57
59
|
|
|
@@ -86,15 +88,15 @@ tqdm_params = {
|
|
|
86
88
|
seed_all()
|
|
87
89
|
|
|
88
90
|
|
|
89
|
-
def
|
|
90
|
-
logger.info("start
|
|
91
|
+
def acc_check(config):
|
|
92
|
+
logger.info("start acc_check test")
|
|
91
93
|
|
|
92
|
-
logger.info(f"
|
|
93
|
-
logger.info(f"
|
|
94
|
+
logger.info(f"acc_check task result will be saved in {config.result_csv_path}")
|
|
95
|
+
logger.info(f"acc_check task details will be saved in {config.details_csv_path}")
|
|
94
96
|
|
|
95
97
|
if config.save_error_data:
|
|
96
|
-
logger.info(f"
|
|
97
|
-
compare = Comparator(config.result_csv_path, config.details_csv_path, config.
|
|
98
|
+
logger.info(f"acc_check task error_data will be saved in {config.error_data_path}")
|
|
99
|
+
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_acc_check, config=config)
|
|
98
100
|
|
|
99
101
|
|
|
100
102
|
csv_df = read_csv(config.result_csv_path)
|
|
@@ -107,8 +109,8 @@ def run_ut(config):
|
|
|
107
109
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
108
110
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
109
111
|
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
110
|
-
logger.info(f"
|
|
111
|
-
logger.info(f"
|
|
112
|
+
logger.info(f"acc_check task result csv is saved in {result_csv_path}")
|
|
113
|
+
logger.info(f"acc_check task details csv is saved in {details_csv_path}")
|
|
112
114
|
compare.print_pretest_result()
|
|
113
115
|
|
|
114
116
|
|
|
@@ -119,13 +121,13 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
119
121
|
if api_full_name in api_name_set:
|
|
120
122
|
continue
|
|
121
123
|
if is_unsupported_api(api_full_name):
|
|
122
|
-
skip_message = f"API {api_full_name} not support for
|
|
124
|
+
skip_message = f"API {api_full_name} not support for acc_check. SKIP."
|
|
123
125
|
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
|
|
124
126
|
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
125
127
|
continue
|
|
126
128
|
_, api_name = extract_basic_api_segments(api_full_name)
|
|
127
129
|
if not api_name:
|
|
128
|
-
err_message = f"API {api_full_name} not support for
|
|
130
|
+
err_message = f"API {api_full_name} not support for acc_check. SKIP."
|
|
129
131
|
logger.error(err_message)
|
|
130
132
|
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
|
|
131
133
|
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
@@ -146,7 +148,7 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
146
148
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
147
149
|
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
148
150
|
else:
|
|
149
|
-
logger.error(f"Run {api_full_name}
|
|
151
|
+
logger.error(f"Run {api_full_name} acc_check Error: %s" % str(err))
|
|
150
152
|
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
|
|
151
153
|
record_skip_info(api_full_name, compare, compare_alg_results)
|
|
152
154
|
finally:
|
|
@@ -208,7 +210,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
208
210
|
device_info_kwargs = kwargs.get(Const.DEVICE)
|
|
209
211
|
if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
|
|
210
212
|
kwargs[Const.DEVICE] = current_device
|
|
211
|
-
device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
|
|
213
|
+
device_args, device_kwargs, is_fp8 = generate_device_params(args, kwargs, need_backward, api_name)
|
|
212
214
|
if kwargs.get(Const.DEVICE):
|
|
213
215
|
del kwargs[Const.DEVICE]
|
|
214
216
|
cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
|
|
@@ -223,6 +225,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
223
225
|
device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
|
|
224
226
|
autocast_dtype)
|
|
225
227
|
device_out = exec_api(device_exec_params)
|
|
228
|
+
if is_fp8 and isinstance(device_out, torch.Tensor) and device_out.dtype == torch.float32:
|
|
229
|
+
device_out = device_out.to(torch.float16)
|
|
226
230
|
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
227
231
|
ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
|
|
228
232
|
api_setting_dict = get_json_contents(ut_setting_path)
|
|
@@ -251,7 +255,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
251
255
|
out = safe_get_value(out, 0, "out")
|
|
252
256
|
device_out = safe_get_value(device_out, 0, "device_out")
|
|
253
257
|
|
|
254
|
-
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message
|
|
258
|
+
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message,
|
|
259
|
+
rank=0, is_fp8=is_fp8)
|
|
255
260
|
|
|
256
261
|
|
|
257
262
|
def check_need_grad(api_info_dict):
|
|
@@ -300,7 +305,7 @@ def extract_tensors_grad(args, depth=0):
|
|
|
300
305
|
if isinstance(arg, torch.Tensor):
|
|
301
306
|
grads.append(arg.grad)
|
|
302
307
|
elif isinstance(arg, (list, tuple)):
|
|
303
|
-
grads.extend(extract_tensors_grad(arg, depth+1))
|
|
308
|
+
grads.extend(extract_tensors_grad(arg, depth + 1))
|
|
304
309
|
return grads
|
|
305
310
|
|
|
306
311
|
|
|
@@ -313,13 +318,13 @@ def initialize_save_error_data(error_data_path):
|
|
|
313
318
|
return error_data_path
|
|
314
319
|
|
|
315
320
|
|
|
316
|
-
def
|
|
321
|
+
def _acc_check_parser(parser):
|
|
317
322
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
|
|
318
323
|
help="<Optional> The api param tool result file: generate from api param tool, "
|
|
319
324
|
"a json file.",
|
|
320
325
|
required=False)
|
|
321
326
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
322
|
-
help="<optional> The
|
|
327
|
+
help="<optional> The acc_check task result out path.",
|
|
323
328
|
required=False)
|
|
324
329
|
parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
|
|
325
330
|
help="<optional> Save compare failed api output.", required=False)
|
|
@@ -337,11 +342,11 @@ def _run_ut_parser(parser):
|
|
|
337
342
|
setattr(namespace, self.dest, values)
|
|
338
343
|
|
|
339
344
|
parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
|
|
340
|
-
help="<optional> set device id to
|
|
345
|
+
help="<optional> set device id to acc_check, must be unique and in range 0-7",
|
|
341
346
|
default=[0], required=False, action=UniqueDeviceAction)
|
|
342
347
|
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
|
|
343
348
|
help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
|
|
344
|
-
"when
|
|
349
|
+
"when acc_check is interrupted, enter the file path to continue acc_check.",
|
|
345
350
|
required=False)
|
|
346
351
|
parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
|
|
347
352
|
help="<optional> Whether to filter the api in the api_info_file.", required=False)
|
|
@@ -389,27 +394,27 @@ def preprocess_forward_content(forward_content):
|
|
|
389
394
|
return processed_content
|
|
390
395
|
|
|
391
396
|
|
|
392
|
-
def
|
|
397
|
+
def _acc_check(parser=None):
|
|
393
398
|
if not parser:
|
|
394
399
|
parser = argparse.ArgumentParser()
|
|
395
|
-
|
|
400
|
+
_acc_check_parser(parser)
|
|
396
401
|
args = parser.parse_args(sys.argv[1:])
|
|
397
|
-
|
|
402
|
+
acc_check_command(args)
|
|
398
403
|
|
|
399
404
|
|
|
400
|
-
def
|
|
405
|
+
def acc_check_command(args):
|
|
401
406
|
if args.config_path:
|
|
402
407
|
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
403
408
|
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
404
409
|
checked_config_path = config_path_checker.common_check()
|
|
405
|
-
_, task_config = parse_json_config(checked_config_path, Const.
|
|
410
|
+
_, task_config = parse_json_config(checked_config_path, Const.ACC_CHECK)
|
|
406
411
|
checker_config = CheckerConfig(task_config)
|
|
407
412
|
else:
|
|
408
413
|
checker_config = CheckerConfig()
|
|
409
414
|
|
|
410
415
|
if not args.api_info_file:
|
|
411
|
-
logger.error("Please provide api_info_file for offline
|
|
412
|
-
raise Exception("Please provide api_info_file for offline
|
|
416
|
+
logger.error("Please provide api_info_file for offline acc_check.")
|
|
417
|
+
raise Exception("Please provide api_info_file for offline acc_check.")
|
|
413
418
|
|
|
414
419
|
if not is_gpu:
|
|
415
420
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
@@ -476,14 +481,14 @@ def run_ut_command(args):
|
|
|
476
481
|
'result_csv_path': result_csv_path,
|
|
477
482
|
'details_csv_path': details_csv_path,
|
|
478
483
|
'save_error_data': save_error_data,
|
|
479
|
-
'
|
|
484
|
+
'is_continue_acc_check': args.result_csv_path,
|
|
480
485
|
'real_data_path': real_data_path,
|
|
481
486
|
'error_data_path': error_data_path
|
|
482
487
|
}
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
logger.info("
|
|
488
|
+
acc_check_config = checker_config.get_acc_check_config(**config_params)
|
|
489
|
+
acc_check(acc_check_config)
|
|
490
|
+
logger.info("acc_check task completed.")
|
|
486
491
|
|
|
487
492
|
|
|
488
493
|
if __name__ == '__main__':
|
|
489
|
-
|
|
494
|
+
_acc_check()
|
msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut_utils.py → acc_check/acc_check_utils.py}
RENAMED
|
@@ -1,38 +1,41 @@
|
|
|
1
|
-
#
|
|
2
|
-
#
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# This file is part of the MindStudio project.
|
|
3
|
+
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
|
|
3
4
|
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
# You may obtain a copy of
|
|
5
|
+
# MindStudio is licensed under Mulan PSL v2.
|
|
6
|
+
# You can use this software according to the terms and conditions of the Mulan PSL v2.
|
|
7
|
+
# You may obtain a copy of Mulan PSL v2 at:
|
|
7
8
|
#
|
|
8
|
-
#
|
|
9
|
+
# http://license.coscl.org.cn/MulanPSL2
|
|
9
10
|
#
|
|
10
|
-
#
|
|
11
|
-
#
|
|
12
|
-
#
|
|
13
|
-
# See the
|
|
14
|
-
#
|
|
11
|
+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
|
12
|
+
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
13
|
+
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
14
|
+
# See the Mulan PSL v2 for more details.
|
|
15
|
+
# -------------------------------------------------------------------------
|
|
15
16
|
|
|
16
17
|
import os
|
|
17
18
|
from collections import namedtuple
|
|
18
19
|
import re
|
|
19
|
-
|
|
20
|
+
import numpy as np
|
|
20
21
|
import torch
|
|
21
22
|
try:
|
|
22
23
|
import torch_npu
|
|
23
24
|
except ImportError:
|
|
24
25
|
current_device = "cuda"
|
|
25
26
|
from torch.cuda.amp import autocast
|
|
27
|
+
IS_GPU = True
|
|
26
28
|
else:
|
|
27
29
|
current_device = "npu"
|
|
28
30
|
from torch_npu.npu.amp import autocast
|
|
31
|
+
IS_GPU = False
|
|
29
32
|
|
|
30
33
|
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
31
34
|
from msprobe.core.common.file_utils import FileChecker
|
|
32
35
|
from msprobe.core.common.log import logger
|
|
33
36
|
from msprobe.core.common.utils import CompareException
|
|
34
|
-
from msprobe.pytorch.
|
|
35
|
-
from msprobe.pytorch.
|
|
37
|
+
from msprobe.pytorch.dump.api_dump.api_register import ApiTemplate, get_api_register
|
|
38
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8, is_hifloat8_tensor
|
|
36
39
|
|
|
37
40
|
|
|
38
41
|
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
@@ -61,7 +64,7 @@ class BackwardMessage:
|
|
|
61
64
|
|
|
62
65
|
class UtDataInfo:
|
|
63
66
|
def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
|
|
64
|
-
backward_message, rank=0):
|
|
67
|
+
backward_message, rank=0, is_fp8=False):
|
|
65
68
|
self.bench_grad = bench_grad
|
|
66
69
|
self.device_grad = device_grad
|
|
67
70
|
self.device_output = device_output
|
|
@@ -70,6 +73,7 @@ class UtDataInfo:
|
|
|
70
73
|
self.in_fwd_data_list = in_fwd_data_list
|
|
71
74
|
self.backward_message = backward_message
|
|
72
75
|
self.rank = rank
|
|
76
|
+
self.is_fp8 = is_fp8
|
|
73
77
|
|
|
74
78
|
|
|
75
79
|
def get_validated_result_csv_path(result_csv_path, mode):
|
|
@@ -82,7 +86,7 @@ def get_validated_result_csv_path(result_csv_path, mode):
|
|
|
82
86
|
result_csv_name = os.path.basename(validated_result_csv_path)
|
|
83
87
|
pattern = r"^accuracy_checking_result_\d{14}\.csv$"
|
|
84
88
|
if not re.match(pattern, result_csv_name):
|
|
85
|
-
raise ValueError("When continue
|
|
89
|
+
raise ValueError("When continue acc_check, please do not modify the result csv name.")
|
|
86
90
|
return validated_result_csv_path
|
|
87
91
|
|
|
88
92
|
|
|
@@ -117,15 +121,12 @@ def exec_api(exec_params):
|
|
|
117
121
|
):
|
|
118
122
|
return out
|
|
119
123
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
api_register.initialize_hook(None)
|
|
125
|
-
api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
|
|
126
|
-
api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
|
|
124
|
+
api_register = get_api_register()
|
|
125
|
+
api_register.initialize_hook(None)
|
|
126
|
+
api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
|
|
127
|
+
api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
|
|
127
128
|
|
|
128
|
-
|
|
129
|
+
torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
|
|
129
130
|
if is_autocast:
|
|
130
131
|
with autocast(dtype=autocast_dtype):
|
|
131
132
|
out = torch_api.forward(*args, **kwargs)
|
|
@@ -148,6 +149,10 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
|
|
|
148
149
|
输出:
|
|
149
150
|
arg: 转换dtype的标杆输入
|
|
150
151
|
'''
|
|
152
|
+
if is_hifloat8_tensor(arg):
|
|
153
|
+
return hif8_to_fp32(arg)
|
|
154
|
+
if is_dtype_fp8(arg.dtype):
|
|
155
|
+
return fp8_to_fp32(arg)
|
|
151
156
|
if api_name in hf_32_standard_api and arg.dtype == torch.float32:
|
|
152
157
|
return arg
|
|
153
158
|
if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype:
|
|
@@ -156,13 +161,18 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
|
|
|
156
161
|
|
|
157
162
|
|
|
158
163
|
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
|
|
164
|
+
is_fp8 = False
|
|
165
|
+
|
|
159
166
|
def recursive_arg_to_device(arg_in, to_detach, depth=0):
|
|
167
|
+
nonlocal is_fp8
|
|
160
168
|
if depth > Const.MAX_DEPTH:
|
|
161
169
|
logger.error("The depth of arg_in is too large, please check the arg_in.")
|
|
162
170
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
163
171
|
if isinstance(arg_in, (list, tuple)):
|
|
164
|
-
return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
|
|
172
|
+
return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth + 1) for arg in arg_in)
|
|
165
173
|
elif isinstance(arg_in, torch.Tensor):
|
|
174
|
+
if is_dtype_fp8(arg_in.dtype) or is_hifloat8_tensor(arg_in):
|
|
175
|
+
is_fp8 = True
|
|
166
176
|
if need_backward and arg_in.requires_grad:
|
|
167
177
|
arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
|
|
168
178
|
temp_arg_in = arg_in * 1
|
|
@@ -178,7 +188,7 @@ def generate_device_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
178
188
|
device_args = recursive_arg_to_device(input_args, is_detach)
|
|
179
189
|
device_kwargs = \
|
|
180
190
|
{key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
|
|
181
|
-
return device_args, device_kwargs
|
|
191
|
+
return device_args, device_kwargs, is_fp8
|
|
182
192
|
|
|
183
193
|
|
|
184
194
|
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
@@ -187,7 +197,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
187
197
|
logger.error("The depth of arg_in is too large, please check the arg_in.")
|
|
188
198
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
189
199
|
if isinstance(arg_in, (list, tuple)):
|
|
190
|
-
return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype, depth=depth+1)
|
|
200
|
+
return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype, depth=depth + 1)
|
|
191
201
|
for arg in arg_in)
|
|
192
202
|
elif isinstance(arg_in, torch.Tensor):
|
|
193
203
|
if need_backward and arg_in.requires_grad:
|
|
@@ -214,12 +224,12 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
214
224
|
logger.error("The depth of arg_in is too large, please check the arg_in.")
|
|
215
225
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
216
226
|
if isinstance(arg_in, (list, tuple)):
|
|
217
|
-
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
|
|
227
|
+
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth + 1) for
|
|
218
228
|
arg in arg_in))
|
|
219
229
|
elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
|
|
220
230
|
return set([arg_in.dtype])
|
|
221
231
|
elif isinstance(arg_in, dict) and check_kwargs:
|
|
222
|
-
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
|
|
232
|
+
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth + 1) for
|
|
223
233
|
v in arg_in.values()))
|
|
224
234
|
return set()
|
|
225
235
|
|
|
@@ -258,5 +268,166 @@ def is_unsupported_api(api_name, is_overflow_check=False):
|
|
|
258
268
|
unsupport_type_list = [Const.DISTRIBUTED, Const.MINDSPEED_API_TYPE_PREFIX]
|
|
259
269
|
flag = (split_name in unsupport_type_list) or (is_overflow_check and split_name == Const.NPU)
|
|
260
270
|
if flag:
|
|
261
|
-
logger.info(f"{split_name} api is not supported for
|
|
271
|
+
logger.info(f"{split_name} api is not supported for acc_check. SKIP.")
|
|
262
272
|
return flag
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def fp8_to_fp32(x):
|
|
276
|
+
"""
|
|
277
|
+
将FP8格式的张量转换为FP32格式,保持原始FP8的表示范围
|
|
278
|
+
使用纯PyTorch操作替代NumPy位运算
|
|
279
|
+
|
|
280
|
+
参数:
|
|
281
|
+
x: 输入的FP8张量,可以是torch.float8_e4m3fn或torch.float8_e5m2类型
|
|
282
|
+
|
|
283
|
+
返回:
|
|
284
|
+
torch.Tensor: 转换后的FP32张量,保持原始FP8的表示范围
|
|
285
|
+
"""
|
|
286
|
+
if x.dtype == torch.float8_e4m3fn:
|
|
287
|
+
# E4M3FN格式:1符号+4指数+3尾数,偏置7
|
|
288
|
+
# 位布局:SEEEEEMM
|
|
289
|
+
|
|
290
|
+
# 将FP8值视为无符号整数进行位操作
|
|
291
|
+
x_int = x.view(torch.uint8)
|
|
292
|
+
|
|
293
|
+
# 提取符号位、指数位和尾数位
|
|
294
|
+
sign_bits = (x_int & 0x80) >> 7 # 最高位是符号位
|
|
295
|
+
exp_bits = (x_int & 0x78) >> 3 # 接下来4位是指数位
|
|
296
|
+
mantissa_bits = x_int & 0x07 # 最后3位是尾数位
|
|
297
|
+
|
|
298
|
+
# 处理规格化数和非规格化数
|
|
299
|
+
is_normal = exp_bits != 0
|
|
300
|
+
|
|
301
|
+
# 计算FP32的指数部分(偏置127)
|
|
302
|
+
fp32_exp = torch.where(
|
|
303
|
+
is_normal,
|
|
304
|
+
(exp_bits - 7 + 127).to(torch.int32), # 规格化数:指数 = 原始指数 + 120
|
|
305
|
+
torch.tensor(0, dtype=torch.int32, device=x.device) # 非规格化数:指数为0
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
# 计算FP32的尾数部分
|
|
309
|
+
# 规格化数:隐含1,尾数 = 1.0 + 原始尾数 * 2^(-3)
|
|
310
|
+
# 非规格化数:无隐含1,尾数 = 0.0 + 原始尾数 * 2^(-3)
|
|
311
|
+
fp32_mantissa = torch.where(
|
|
312
|
+
is_normal,
|
|
313
|
+
1.0 + mantissa_bits.to(torch.float32) / 8.0, # 2^(-3) = 1/8
|
|
314
|
+
mantissa_bits.to(torch.float32) / 8.0
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
# 计算符号值 (-1)^sign
|
|
318
|
+
sign_value = torch.pow(-1.0, sign_bits.to(torch.float32))
|
|
319
|
+
|
|
320
|
+
# 计算最终FP32值
|
|
321
|
+
# 规格化数:value = (-1)^sign * (1.0 + mantissa/8) * 2^(exp - 7)
|
|
322
|
+
# 非规格化数:value = (-1)^sign * (mantissa/8) * 2^(-6)
|
|
323
|
+
fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127)
|
|
324
|
+
|
|
325
|
+
return fp32_result
|
|
326
|
+
|
|
327
|
+
elif x.dtype == torch.float8_e5m2:
|
|
328
|
+
# E5M2格式:1符号+5指数+2尾数,偏置15
|
|
329
|
+
# 位布局:SEEEEEEM
|
|
330
|
+
|
|
331
|
+
# 将FP8值视为无符号整数进行位操作
|
|
332
|
+
x_int = x.view(torch.uint8)
|
|
333
|
+
|
|
334
|
+
# 提取符号位、指数位和尾数位
|
|
335
|
+
sign_bits = (x_int & 0x80) >> 7 # 最高位是符号位
|
|
336
|
+
exp_bits = (x_int & 0x7C) >> 2 # 接下来5位是指数位
|
|
337
|
+
mantissa_bits = x_int & 0x03 # 最后2位是尾数位
|
|
338
|
+
|
|
339
|
+
# 处理规格化数和非规格化数
|
|
340
|
+
is_normal = exp_bits != 0
|
|
341
|
+
|
|
342
|
+
# 计算FP32的指数部分(偏置127)
|
|
343
|
+
fp32_exp = torch.where(
|
|
344
|
+
is_normal,
|
|
345
|
+
(exp_bits - 15 + 127).to(torch.int32), # 规格化数:指数 = 原始指数 + 112
|
|
346
|
+
torch.tensor(0, dtype=torch.int32, device=x.device) # 非规格化数:指数为0
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
# 计算FP32的尾数部分
|
|
350
|
+
# 规格化数:隐含1,尾数 = 1.0 + 原始尾数 * 2^(-2)
|
|
351
|
+
# 非规格化数:无隐含1,尾数 = 0.0 + 原始尾数 * 2^(-2)
|
|
352
|
+
fp32_mantissa = torch.where(
|
|
353
|
+
is_normal,
|
|
354
|
+
1.0 + mantissa_bits.to(torch.float32) / 4.0, # 2^(-2) = 1/4
|
|
355
|
+
mantissa_bits.to(torch.float32) / 4.0
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
# 计算符号值 (-1)^sign
|
|
359
|
+
sign_value = torch.pow(-1.0, sign_bits.to(torch.float32))
|
|
360
|
+
|
|
361
|
+
# 计算最终FP32值
|
|
362
|
+
fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127)
|
|
363
|
+
|
|
364
|
+
return fp32_result
|
|
365
|
+
|
|
366
|
+
else:
|
|
367
|
+
raise ValueError(f"Unsupported dtype: {x.dtype}. Expected torch.float8_e4m3fn or torch.float8_e5m2.")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def hif8_to_fp32(x):
    """
    Convert a HiFloat8 tensor to a float32 CPU tensor.

    The tensor is detached, copied to the CPU and converted to NumPy, then
    each element ``z`` is mapped as follows:

    * NaN and +/-Inf are kept as they are;
    * zero (either sign) becomes ``0.0``;
    * with ``e = floor(log2(|z|))``: when ``|e| <= 15`` the value is kept
      unchanged (the 3/2/1-bit mantissa cases scale and unscale by the same
      factor, so they reduce to the identity); when ``|e| > 15`` (0-bit
      mantissa) the value is replaced by ``sign(z) * 2 ** e``.

    NOTE(review): this presumes ``x.cpu().detach().numpy()`` already yields
    decoded numeric values for a HiFloat8 tensor - confirm against torch_npu.

    Args:
        x: input tensor, expected to be a ``torch_npu.HiFloat8Tensor``.

    Returns:
        torch.Tensor: float32 CPU tensor with the same shape as ``x``;
        ``requires_grad`` is set when it was set on ``x``.
    """
    requires_grad = x.requires_grad
    values = np.asarray(x.cpu().detach().numpy())

    # Work on the whole array at once; this also handles 0-d inputs, for which
    # np.prod(shape) is the float 1.0 and cannot drive a range() loop.
    magnitude = np.abs(values).astype(np.float64)
    is_special = np.isnan(values) | np.isinf(values)
    is_zero = magnitude == 0
    is_regular = ~(is_special | is_zero)

    # Substitute 1.0 for zero/NaN/Inf so log2 never sees them; those positions
    # are never selected by ``power_only`` below.
    safe_magnitude = np.where(is_regular, magnitude, 1.0)
    exponent = np.floor(np.log2(safe_magnitude + 1e-100))

    # Beyond |exponent| 15 there are no mantissa bits: only sign and power of two remain.
    power_only = is_regular & (np.abs(exponent) > 15)
    sign = np.where(values >= 0, 1.0, -1.0)

    restored = np.where(is_zero, 0.0, values)
    restored = np.where(power_only, sign * (2.0 ** exponent), restored)

    res = torch.from_numpy(np.asarray(restored, dtype=np.float32))
    if requires_grad:
        res = res.requires_grad_()
    return res
|