mindstudio-probe 8.3.2__py3-none-any.whl → 26.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.2.dist-info → mindstudio_probe-26.0.0a1.dist-info}/METADATA +26 -14
- mindstudio_probe-26.0.0a1.dist-info/RECORD +498 -0
- {mindstudio_probe-8.3.2.dist-info → mindstudio_probe-26.0.0a1.dist-info}/WHEEL +1 -1
- mindstudio_probe-26.0.0a1.dist-info/entry_points.txt +5 -0
- mindstudio_probe-26.0.0a1.dist-info/licenses/LICENSE +124 -0
- mindstudio_probe-26.0.0a1.dist-info/top_level.txt +2 -0
- msprobe/__init__.py +12 -13
- msprobe/config.json +9 -31
- msprobe/core/__init__.py +12 -11
- msprobe/core/acc_check/acc_check_cli.py +145 -0
- msprobe/core/common/const.py +97 -38
- msprobe/core/common/db_manager.py +133 -12
- msprobe/core/common/decorator.py +12 -11
- msprobe/core/common/exceptions.py +12 -11
- msprobe/core/common/file_utils.py +101 -25
- msprobe/core/common/framework_adapter.py +36 -25
- msprobe/core/common/global_lock.py +12 -11
- msprobe/core/common/inplace_op_checker.py +12 -11
- msprobe/core/common/log.py +22 -11
- msprobe/core/common/megatron_utils.py +566 -11
- msprobe/core/common/parallel_state.py +12 -11
- msprobe/core/common/runtime.py +12 -11
- msprobe/core/common/utils.py +41 -41
- msprobe/core/compare/acc_compare.py +361 -104
- msprobe/core/compare/atb_data_compare.py +422 -0
- msprobe/core/compare/auto_compare.py +134 -0
- msprobe/core/compare/check.py +14 -17
- msprobe/core/compare/compare_cli.py +72 -149
- msprobe/core/compare/config.py +12 -13
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +28 -15
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/analyzer.py +18 -18
- msprobe/core/compare/find_first/graph.py +12 -11
- msprobe/core/compare/find_first/utils.py +13 -12
- msprobe/core/compare/indicator_analysis/__init__.py +15 -0
- msprobe/core/compare/indicator_analysis/algorithm.py +363 -0
- msprobe/core/compare/indicator_analysis/api_data.py +141 -0
- msprobe/core/compare/indicator_analysis/calculator.py +181 -0
- msprobe/core/compare/indicator_analysis/utils.py +116 -0
- msprobe/core/compare/layer_mapping/__init__.py +12 -11
- msprobe/core/compare/layer_mapping/data_scope_parser.py +20 -11
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -13
- msprobe/core/compare/layer_mapping/postprocess_pass.py +13 -11
- msprobe/core/compare/merge_result/merge_result.py +12 -11
- msprobe/core/compare/merge_result/merge_result_cli.py +12 -11
- msprobe/core/compare/merge_result/utils.py +12 -11
- msprobe/core/compare/multiprocessing_compute.py +13 -14
- msprobe/core/compare/npy_compare.py +13 -11
- msprobe/core/compare/offline_data_compare.py +160 -0
- msprobe/core/compare/stats_diff_calc.py +39 -0
- msprobe/core/compare/torchair_acc_cmp.py +764 -0
- msprobe/core/compare/torchair_cmp_utils.py +338 -0
- msprobe/core/compare/utils.py +140 -49
- msprobe/core/config_check/__init__.py +12 -11
- msprobe/core/config_check/checkers/__init__.py +12 -11
- msprobe/core/config_check/checkers/base_checker.py +15 -14
- msprobe/core/config_check/checkers/dataset_checker.py +13 -12
- msprobe/core/config_check/checkers/env_args_checker.py +13 -12
- msprobe/core/config_check/checkers/hyperparameter_checker.py +16 -15
- msprobe/core/config_check/checkers/pip_checker.py +15 -15
- msprobe/core/config_check/checkers/random_checker.py +13 -12
- msprobe/core/config_check/checkers/weights_checker.py +14 -12
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +13 -17
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +13 -12
- msprobe/core/config_check/ckpt_compare/metrics.py +12 -11
- msprobe/core/config_check/config_check_cli.py +18 -17
- msprobe/core/config_check/config_checker.py +16 -14
- msprobe/core/config_check/resource/dependency.yaml +15 -12
- msprobe/core/config_check/resource/env.yaml +12 -11
- msprobe/core/config_check/utils/hyperparameter_parser.py +12 -11
- msprobe/core/config_check/utils/utils.py +12 -11
- msprobe/core/{data_dump → dump/api_dump}/api_registry.py +12 -11
- msprobe/core/{common_config.py → dump/common_config.py} +13 -24
- msprobe/core/dump/data_dump/data_collector.py +257 -0
- msprobe/core/{data_dump → dump/data_dump}/data_processor/base.py +45 -36
- msprobe/core/{data_dump → dump/data_dump}/data_processor/factory.py +33 -25
- msprobe/core/{data_dump → dump/data_dump}/data_processor/mindspore_processor.py +37 -113
- msprobe/core/{data_dump → dump/data_dump}/data_processor/pytorch_processor.py +364 -131
- msprobe/core/{data_dump → dump/data_dump}/json_writer.py +24 -31
- msprobe/core/{data_dump → dump/data_dump}/scope.py +12 -13
- msprobe/core/{debugger → dump/debugger}/precision_debugger.py +15 -23
- msprobe/core/dump/dump2db/db_utils.py +215 -0
- msprobe/core/dump/dump2db/dump2db.py +409 -0
- msprobe/core/{hook_manager.py → dump/hook_manager.py} +38 -87
- msprobe/core/dump/kernel_dump/kernel_config.py +34 -0
- msprobe/core/{service.py → dump/service.py} +43 -27
- msprobe/core/install_deps/install_deps.py +51 -0
- msprobe/core/monitor/anomaly_processor.py +13 -11
- msprobe/core/monitor/csv2db.py +73 -93
- msprobe/core/monitor/db_utils.py +140 -205
- msprobe/core/monitor/utils.py +18 -17
- msprobe/core/monitor_v2/__init__.py +20 -0
- msprobe/core/monitor_v2/base.py +83 -0
- msprobe/core/monitor_v2/cc.py +287 -0
- msprobe/core/monitor_v2/factory.py +81 -0
- msprobe/core/monitor_v2/module.py +201 -0
- msprobe/core/monitor_v2/optimizer.py +245 -0
- msprobe/core/monitor_v2/param.py +154 -0
- msprobe/core/monitor_v2/trainer.py +326 -0
- msprobe/core/monitor_v2/utils.py +122 -0
- msprobe/core/monitor_v2/weight_grad.py +419 -0
- msprobe/core/monitor_v2/writer.py +162 -0
- msprobe/core/overflow_check/abnormal_scene.py +12 -11
- msprobe/core/overflow_check/api_info.py +12 -11
- msprobe/core/overflow_check/checker.py +12 -11
- msprobe/core/overflow_check/filter.py +13 -11
- msprobe/core/overflow_check/level.py +12 -11
- msprobe/core/overflow_check/utils.py +12 -11
- msprobe/core/single_save/single_comparator.py +12 -11
- msprobe/core/single_save/single_saver.py +12 -11
- msprobe/infer/__init__.py +16 -0
- msprobe/infer/offline/__init__.py +16 -0
- msprobe/infer/offline/compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/args_adapter.py +46 -0
- msprobe/infer/offline/compare/msquickcmp/atc/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/atc/atc_utils.py +98 -0
- msprobe/infer/offline/compare/msquickcmp/cmp_process.py +328 -0
- msprobe/infer/offline/compare/msquickcmp/common/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/common/args_check.py +112 -0
- msprobe/infer/offline/compare/msquickcmp/common/convert.py +74 -0
- msprobe/infer/offline/compare/msquickcmp/common/dump_data.py +121 -0
- msprobe/infer/offline/compare/msquickcmp/common/dynamic_argument_bean.py +39 -0
- msprobe/infer/offline/compare/msquickcmp/common/utils.py +669 -0
- msprobe/infer/offline/compare/msquickcmp/config.ini +6 -0
- msprobe/infer/offline/compare/msquickcmp/dump/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/dump/args_adapter.py +50 -0
- msprobe/infer/offline/compare/msquickcmp/dump/dump_process.py +91 -0
- msprobe/infer/offline/compare/msquickcmp/install_aclruntime_aisbench.sh +180 -0
- msprobe/infer/offline/compare/msquickcmp/main.py +199 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/net_compare.py +277 -0
- msprobe/infer/offline/compare/msquickcmp/npu/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/npu/npu_dump_data.py +558 -0
- msprobe/infer/offline/compare/msquickcmp/npu/om_parser.py +416 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/onnx_dump_data.py +374 -0
- msprobe/infer/utils/__init__.py +15 -0
- msprobe/infer/utils/acc_cmp.py +94 -0
- msprobe/infer/utils/check/__init__.py +37 -0
- msprobe/infer/utils/check/args_checker.py +35 -0
- msprobe/infer/utils/check/checker.py +227 -0
- msprobe/infer/utils/check/dict_checker.py +78 -0
- msprobe/infer/utils/check/func_wrapper.py +96 -0
- msprobe/infer/utils/check/list_checker.py +56 -0
- msprobe/infer/utils/check/number_checker.py +64 -0
- msprobe/infer/utils/check/obj_checker.py +41 -0
- msprobe/infer/utils/check/path_checker.py +249 -0
- msprobe/infer/utils/check/rule.py +126 -0
- msprobe/infer/utils/check/string_checker.py +66 -0
- msprobe/infer/utils/cmp_algorithm.py +261 -0
- msprobe/infer/utils/constants.py +112 -0
- msprobe/infer/utils/file_open_check.py +337 -0
- msprobe/infer/utils/util.py +177 -0
- msprobe/mindspore/__init__.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_info.py +12 -11
- msprobe/mindspore/api_accuracy_checker/api_runner.py +12 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +12 -11
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +12 -11
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +15 -14
- msprobe/mindspore/api_accuracy_checker/compute_element.py +12 -11
- msprobe/mindspore/api_accuracy_checker/data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/main.py +12 -11
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +14 -12
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +12 -11
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +12 -11
- msprobe/mindspore/api_accuracy_checker/utils.py +12 -11
- msprobe/mindspore/common/const.py +15 -74
- msprobe/mindspore/common/log.py +12 -11
- msprobe/mindspore/common/utils.py +30 -15
- msprobe/mindspore/compare/common_dir_compare.py +21 -23
- msprobe/mindspore/compare/distributed_compare.py +18 -16
- msprobe/mindspore/compare/ms_compare.py +14 -14
- msprobe/mindspore/compare/ms_graph_compare.py +26 -20
- msprobe/mindspore/compare/utils.py +14 -12
- msprobe/mindspore/{cell_processor.py → dump/cell_processor.py} +15 -14
- msprobe/mindspore/{debugger → dump/debugger}/debugger_config.py +12 -30
- msprobe/mindspore/{debugger → dump/debugger}/precision_debugger.py +43 -45
- msprobe/mindspore/dump/{cell_dump_process.py → dump_processor/cell_dump_process.py} +31 -17
- msprobe/mindspore/dump/{cell_dump_with_insert_gradient.py → dump_processor/cell_dump_with_insert_gradient.py} +18 -14
- msprobe/mindspore/dump/{dump_tool_factory.py → dump_processor/dump_tool_factory.py} +16 -15
- msprobe/mindspore/dump/{graph_mode_cell_dump.py → dump_processor/graph_mode_cell_dump.py} +16 -15
- msprobe/mindspore/dump/{graph_tensor_dump.py → dump_processor/graph_tensor_dump.py} +134 -133
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/api_register.py +15 -14
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/hook_cell.py +12 -11
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/ms_hook_manager.py +47 -20
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/primitive_hooks.py +14 -13
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/support_wrap_ops.yaml +13 -11
- msprobe/mindspore/dump/{jit_dump.py → dump_processor/jit_dump.py} +14 -13
- msprobe/mindspore/dump/{kernel_graph_dump.py → dump_processor/kernel_graph_dump.py} +13 -12
- msprobe/mindspore/dump/{kernel_kbyk_dump.py → dump_processor/kernel_kbyk_dump.py} +13 -12
- msprobe/mindspore/{exception_dump → dump/exception_dump}/exception_dump_tool_factory.py +14 -13
- msprobe/mindspore/{exception_dump → dump/exception_dump}/kernel_graph_exception_dump.py +13 -12
- msprobe/mindspore/{mindspore_service.py → dump/mindspore_service.py} +18 -17
- msprobe/mindspore/dump/mindtorch/__init__.py +19 -0
- msprobe/mindspore/dump/ms_config.py +105 -0
- msprobe/mindspore/{overflow_check → dump/overflow_check}/kernel_graph_overflow_check.py +13 -12
- msprobe/mindspore/{overflow_check → dump/overflow_check}/overflow_check_tool_factory.py +14 -13
- msprobe/mindspore/dump/task_handler_factory.py +43 -0
- msprobe/mindspore/monitor/common_func.py +12 -11
- msprobe/mindspore/monitor/data_writers.py +12 -11
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +93 -39
- msprobe/mindspore/monitor/features.py +12 -11
- msprobe/mindspore/monitor/module_hook.py +19 -22
- msprobe/mindspore/monitor/optimizer_collect.py +29 -25
- msprobe/mindspore/monitor/utils.py +13 -11
- msprobe/msaccucmp/advisor/__init__.py +16 -0
- msprobe/msaccucmp/advisor/advisor_const.py +65 -0
- msprobe/msaccucmp/advisor/advisor_result.py +73 -0
- msprobe/msaccucmp/advisor/compare_advisor.py +99 -0
- msprobe/msaccucmp/advisor/input_advisor.py +66 -0
- msprobe/msaccucmp/advisor/node_advisor.py +68 -0
- msprobe/msaccucmp/advisor/overflow_advisor.py +58 -0
- msprobe/msaccucmp/algorithm_manager/__init__.py +16 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_manager.py +464 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_parameter.py +42 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_AccumulatedRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_CosineSimilarity.py +58 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_KullbackLeiblerDivergence.py +84 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RelativeEuclideanDistance.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RootMeanSquareError.py +40 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_StandardDeviation.py +47 -0
- msprobe/msaccucmp/cmp_utils/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/common.py +113 -0
- msprobe/msaccucmp/cmp_utils/constant/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/constant/compare_error.py +81 -0
- msprobe/msaccucmp/cmp_utils/constant/const_manager.py +530 -0
- msprobe/msaccucmp/cmp_utils/file_utils.py +497 -0
- msprobe/msaccucmp/cmp_utils/log.py +257 -0
- msprobe/msaccucmp/cmp_utils/multi_process/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/multi_process/multi_convert_process.py +140 -0
- msprobe/msaccucmp/cmp_utils/multi_process/progress.py +78 -0
- msprobe/msaccucmp/cmp_utils/path_check.py +274 -0
- msprobe/msaccucmp/cmp_utils/reg_manager.py +98 -0
- msprobe/msaccucmp/cmp_utils/tlv_parse.py +279 -0
- msprobe/msaccucmp/cmp_utils/utils.py +356 -0
- msprobe/msaccucmp/cmp_utils/utils_type.py +63 -0
- msprobe/msaccucmp/compare_vector.py +48 -0
- msprobe/msaccucmp/conversion/__init__.py +16 -0
- msprobe/msaccucmp/conversion/data_conversion.py +277 -0
- msprobe/msaccucmp/conversion/dtype_conversion.py +99 -0
- msprobe/msaccucmp/conversion/shape_format_conversion.py +477 -0
- msprobe/msaccucmp/conversion/tensor_conversion.py +369 -0
- msprobe/msaccucmp/dump_data_conversion.py +46 -0
- msprobe/msaccucmp/dump_parse/__init__.py +16 -0
- msprobe/msaccucmp/dump_parse/big_dump_data.py +317 -0
- msprobe/msaccucmp/dump_parse/dump.py +423 -0
- msprobe/msaccucmp/dump_parse/dump_data_object.py +322 -0
- msprobe/msaccucmp/dump_parse/dump_data_parser.py +436 -0
- msprobe/msaccucmp/dump_parse/dump_utils.py +246 -0
- msprobe/msaccucmp/dump_parse/ffts_parser.py +137 -0
- msprobe/msaccucmp/dump_parse/mapping.py +62 -0
- msprobe/msaccucmp/dump_parse/nano_dump_data.py +392 -0
- msprobe/msaccucmp/dump_parse/proto_dump_data.py +308 -0
- msprobe/msaccucmp/dump_parser.py +90 -0
- msprobe/msaccucmp/format_manager/__init__.py +16 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NCHW.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_ND.py +52 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NHWC.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_HWCN.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_NCHW.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_FRACTAL_Z.py +89 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_HWCN.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NCHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NHWC.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_NCDHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_ND.py +44 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_HWCN.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/format_manager.py +307 -0
- msprobe/msaccucmp/inplace_layer_process.py +186 -0
- msprobe/msaccucmp/msaccucmp.py +532 -0
- msprobe/msaccucmp/mscmp_advisor.py +128 -0
- msprobe/msaccucmp/overflow/__init__.py +16 -0
- msprobe/msaccucmp/overflow/overflow_analyse.py +305 -0
- msprobe/msaccucmp/overflow/overflow_detection.py +143 -0
- msprobe/msaccucmp/pytorch_cmp/__init__.py +16 -0
- msprobe/msaccucmp/pytorch_cmp/compare_pytorch.py +389 -0
- msprobe/msaccucmp/pytorch_cmp/hdf5_parser.py +377 -0
- msprobe/msaccucmp/pytorch_cmp/pytorch_dump_data.py +461 -0
- msprobe/msaccucmp/shape_conversion.py +41 -0
- msprobe/msaccucmp/vector_cmp/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/batch_compare.py +197 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/compare_detail.py +245 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail.py +182 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail_writer.py +580 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_fusion_op.py +588 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_npu_vs_npu.py +339 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_result.py +326 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_rule.py +156 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_op.py +204 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_rule_parser.py +635 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/quant_filter.py +187 -0
- msprobe/msaccucmp/vector_cmp/range_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_manager.py +100 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_mode.py +94 -0
- msprobe/msaccucmp/vector_cmp/range_manager/select_mode.py +86 -0
- msprobe/msaccucmp/vector_cmp/vector_comparison.py +535 -0
- msprobe/msprobe.py +101 -130
- msprobe/overflow_check/__init__.py +15 -0
- msprobe/{nan_analyze → overflow_check}/analyzer.py +38 -27
- msprobe/{nan_analyze → overflow_check}/graph.py +30 -27
- msprobe/{nan_analyze → overflow_check}/utils.py +15 -14
- msprobe/pytorch/__init__.py +20 -14
- msprobe/pytorch/aclgraph_dump/__init__.py +45 -0
- msprobe/pytorch/aclgraph_dump/_meta.py +26 -0
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut.py → acc_check/acc_check.py} +50 -45
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut_utils.py → acc_check/acc_check_utils.py} +201 -30
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/data_generate.py +56 -16
- msprobe/pytorch/api_accuracy_checker/{run_ut/multi_run_ut.py → acc_check/multi_acc_check.py} +32 -47
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/run_overflow_check.py +19 -18
- msprobe/pytorch/api_accuracy_checker/common/config.py +22 -20
- msprobe/pytorch/api_accuracy_checker/common/utils.py +72 -13
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -11
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +23 -14
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +45 -32
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +12 -11
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +14 -12
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +14 -12
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +21 -19
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +14 -13
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +60 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +27 -16
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +13 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +39 -18
- msprobe/pytorch/bench_functions/__init__.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam_w.py +12 -11
- msprobe/pytorch/bench_functions/confusion_transpose.py +12 -11
- msprobe/pytorch/bench_functions/fast_gelu.py +12 -11
- msprobe/pytorch/bench_functions/group_norm_silu.py +12 -11
- msprobe/pytorch/bench_functions/layer_norm_eval.py +12 -11
- msprobe/pytorch/bench_functions/linear.py +12 -11
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -11
- msprobe/pytorch/bench_functions/mish.py +12 -11
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +12 -11
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +12 -11
- msprobe/pytorch/bench_functions/rms_norm.py +12 -11
- msprobe/pytorch/bench_functions/rotary_mul.py +12 -11
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +12 -11
- msprobe/pytorch/bench_functions/sort_v2.py +12 -11
- msprobe/pytorch/bench_functions/swiglu.py +12 -11
- msprobe/pytorch/common/__init__.py +12 -11
- msprobe/pytorch/common/log.py +12 -11
- msprobe/pytorch/common/parse_json.py +12 -11
- msprobe/pytorch/common/utils.py +52 -19
- msprobe/pytorch/compare/distributed_compare.py +13 -13
- msprobe/pytorch/compare/match.py +12 -11
- msprobe/pytorch/compare/pt_compare.py +14 -20
- msprobe/pytorch/compare/pt_diff_analyze.py +12 -11
- msprobe/pytorch/compare/utils.py +12 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/api_register.py +18 -16
- msprobe/pytorch/{hook_module → dump/api_dump}/hook_module.py +14 -13
- msprobe/pytorch/{hook_module → dump/api_dump}/pt_hook_manager.py +68 -23
- msprobe/pytorch/{hook_module → dump/api_dump}/register_optimizer_hook.py +13 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/script_wrapper.py +17 -14
- msprobe/pytorch/{hook_module → dump/api_dump}/utils.py +12 -11
- msprobe/pytorch/{debugger → dump/debugger}/debugger_config.py +23 -38
- msprobe/pytorch/dump/debugger/precision_debugger.py +130 -0
- msprobe/pytorch/{function_factory.py → dump/function_factory.py} +12 -11
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +17 -13
- msprobe/pytorch/dump/module_dump/module_dump.py +16 -15
- msprobe/pytorch/dump/module_dump/{module_processer.py → module_processor.py} +54 -42
- msprobe/pytorch/dump/pt_config.py +128 -0
- msprobe/pytorch/{pytorch_service.py → dump/pytorch_service.py} +22 -21
- msprobe/pytorch/monitor/csv2tb.py +13 -11
- msprobe/pytorch/monitor/data_writers.py +13 -11
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +13 -11
- msprobe/pytorch/monitor/features.py +12 -11
- msprobe/pytorch/monitor/module_hook.py +67 -59
- msprobe/pytorch/monitor/module_metric.py +13 -11
- msprobe/pytorch/monitor/optimizer_collect.py +37 -35
- msprobe/pytorch/monitor/utils.py +13 -11
- msprobe/pytorch/monitor/visualizer.py +12 -11
- msprobe/pytorch/torchair_dump/__init__.py +17 -0
- msprobe/pytorch/torchair_dump/torchair_dump.py +114 -0
- msprobe/scripts/atb/config_example.json +10 -0
- msprobe/scripts/atb/load_atb_probe.sh +101 -0
- msprobe/scripts/atb/unload_atb_probe.sh +27 -0
- msprobe/scripts/build_msaccucmp.sh +186 -0
- msprobe/scripts/conf/help.info +6 -0
- msprobe/scripts/conf/version.info +3 -0
- msprobe/scripts/run_script/common.sh +538 -0
- msprobe/scripts/run_script/main_msaccucmp.sh +232 -0
- msprobe/visualization/__init__.py +12 -11
- msprobe/visualization/builder/__init__.py +12 -11
- msprobe/visualization/builder/graph_builder.py +45 -30
- msprobe/visualization/builder/graph_merger.py +53 -32
- msprobe/visualization/builder/msprobe_adapter.py +34 -44
- msprobe/visualization/compare/__init__.py +12 -11
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +28 -113
- msprobe/visualization/db_utils.py +133 -22
- msprobe/visualization/graph/__init__.py +12 -11
- msprobe/visualization/graph/base_node.py +15 -27
- msprobe/visualization/graph/distributed_analyzer.py +97 -40
- msprobe/visualization/graph/graph.py +14 -16
- msprobe/visualization/graph/node_colors.py +34 -31
- msprobe/visualization/graph/node_op.py +12 -11
- msprobe/visualization/graph_service.py +580 -205
- msprobe/visualization/utils.py +278 -31
- tb_graph_ascend/secure_build.py +175 -0
- tb_graph_ascend/server/__init__.py +15 -0
- tb_graph_ascend/server/app/__init__.py +15 -0
- tb_graph_ascend/server/app/model/__init__.py +15 -0
- tb_graph_ascend/server/app/model/hierarchy.py +348 -0
- tb_graph_ascend/server/app/model/layout_hierarchy_model.py +69 -0
- tb_graph_ascend/server/app/model/match_nodes_model.py +573 -0
- tb_graph_ascend/server/app/repositories/__init__.py +15 -0
- tb_graph_ascend/server/app/repositories/graph_repo_base.py +32 -0
- tb_graph_ascend/server/app/repositories/graph_repo_db.py +879 -0
- tb_graph_ascend/server/app/repositories/graph_repo_vis.py +83 -0
- tb_graph_ascend/server/app/service/__init__.py +18 -0
- tb_graph_ascend/server/app/service/graph_service_base.py +158 -0
- tb_graph_ascend/server/app/service/graph_service_db.py +438 -0
- tb_graph_ascend/server/app/service/graph_service_factory.py +54 -0
- tb_graph_ascend/server/app/service/graph_service_vis.py +480 -0
- tb_graph_ascend/server/app/utils/__init__.py +15 -0
- tb_graph_ascend/server/app/utils/constant.py +80 -0
- tb_graph_ascend/server/app/utils/file_check_wrapper.py +46 -0
- tb_graph_ascend/server/app/utils/global_state.py +95 -0
- tb_graph_ascend/server/app/utils/graph_utils.py +661 -0
- tb_graph_ascend/server/app/utils/i18n.py +153 -0
- tb_graph_ascend/server/app/utils/request_method.py +46 -0
- tb_graph_ascend/server/app/views/__init__.py +15 -0
- tb_graph_ascend/server/app/views/graph_views.py +304 -0
- tb_graph_ascend/server/plugin.py +108 -0
- tb_graph_ascend/server/static/index.html +9250 -0
- tb_graph_ascend/server/static/index.js +21 -0
- tb_graph_ascend/setup.py +57 -0
- mindstudio_probe-8.3.2.dist-info/LICENSE +0 -201
- mindstudio_probe-8.3.2.dist-info/RECORD +0 -491
- mindstudio_probe-8.3.2.dist-info/entry_points.txt +0 -2
- mindstudio_probe-8.3.2.dist-info/top_level.txt +0 -1
- msprobe/CMakeLists.txt +0 -5
- msprobe/README.md +0 -203
- msprobe/core/advisor/advisor.py +0 -129
- msprobe/core/advisor/advisor_const.py +0 -58
- msprobe/core/advisor/advisor_result.py +0 -58
- msprobe/core/compare/find_first/data_processor.py +0 -35
- msprobe/core/compare/highlight.py +0 -390
- msprobe/core/data_dump/data_collector.py +0 -356
- msprobe/core/grad_probe/constant.py +0 -90
- msprobe/core/grad_probe/grad_compare.py +0 -187
- msprobe/core/grad_probe/utils.py +0 -105
- msprobe/core/kernel_dump/kernel_config.py +0 -33
- msprobe/docs/01.installation.md +0 -250
- msprobe/docs/02.config_introduction.md +0 -221
- msprobe/docs/03.config_examples.md +0 -281
- msprobe/docs/04.kernel_dump_PyTorch.md +0 -73
- msprobe/docs/05.data_dump_PyTorch.md +0 -518
- msprobe/docs/06.data_dump_MindSpore.md +0 -618
- msprobe/docs/07.accuracy_checker_PyTorch.md +0 -310
- msprobe/docs/09.accuracy_checker_MindSpore.md +0 -120
- msprobe/docs/10.accuracy_compare_PyTorch.md +0 -637
- msprobe/docs/11.accuracy_compare_MindSpore.md +0 -769
- msprobe/docs/12.overflow_check_PyTorch.md +0 -82
- msprobe/docs/13.overflow_check_MindSpore.md +0 -33
- msprobe/docs/14.data_parse_PyTorch.md +0 -282
- msprobe/docs/15.free_benchmarking_PyTorch.md +0 -169
- msprobe/docs/16.free_benchmarking_MindSpore.md +0 -159
- msprobe/docs/17.grad_probe.md +0 -205
- msprobe/docs/18.online_dispatch.md +0 -89
- msprobe/docs/19.monitor.md +0 -753
- msprobe/docs/20.monitor_performance_baseline.md +0 -52
- msprobe/docs/21.visualization_PyTorch.md +0 -519
- msprobe/docs/22.visualization_MindSpore.md +0 -515
- msprobe/docs/23.generate_operator_PyTorch.md +0 -107
- msprobe/docs/24.code_mapping_Mindspore.md +0 -29
- msprobe/docs/25.tool_function_introduction.md +0 -29
- msprobe/docs/26.data_dump_PyTorch_baseline.md +0 -48
- msprobe/docs/27.dump_json_instruction.md +0 -795
- msprobe/docs/28.debugger_save_instruction.md +0 -288
- msprobe/docs/28.kernel_dump_MindSpore.md +0 -69
- msprobe/docs/29.data_dump_MSAdapter.md +0 -235
- msprobe/docs/30.overflow_check_MSAdapter.md +0 -31
- msprobe/docs/31.config_check.md +0 -107
- msprobe/docs/32.ckpt_compare.md +0 -69
- msprobe/docs/33.generate_operator_MindSpore.md +0 -181
- msprobe/docs/34.RL_collect.md +0 -101
- msprobe/docs/35.nan_analyze.md +0 -73
- msprobe/docs/36.calculation_result_change.md +0 -75
- msprobe/docs/FAQ.md +0 -232
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +0 -146
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +0 -14
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +0 -33
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +0 -217
- msprobe/docs/img/BLOOM-7B_1.png +0 -0
- msprobe/docs/img/BLOOM-7B_2.png +0 -0
- msprobe/docs/img/BLOOM-7B_3.png +0 -0
- msprobe/docs/img/BLOOM-7B_4.png +0 -0
- msprobe/docs/img/GPT-3_1.png +0 -0
- msprobe/docs/img/GPT-3_2.png +0 -0
- msprobe/docs/img/GPT-3_3.png +0 -0
- msprobe/docs/img/GPT-3_4.png +0 -0
- msprobe/docs/img/GPT-3_5.png +0 -0
- msprobe/docs/img/GPT-3_6.png +0 -0
- msprobe/docs/img/GPT-3_7.png +0 -0
- msprobe/docs/img/GPT-3_8.png +0 -0
- msprobe/docs/img/YOLOV5S_1.png +0 -0
- msprobe/docs/img/YOLOV5S_2.png +0 -0
- msprobe/docs/img/accuracy_checking_details.png +0 -0
- msprobe/docs/img/accuracy_checking_result.png +0 -0
- msprobe/docs/img/api_precision_compare_details.png +0 -0
- msprobe/docs/img/api_precision_compare_result.png +0 -0
- msprobe/docs/img/auto_analyze_log.png +0 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/compare_result_pkl.png +0 -0
- msprobe/docs/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/docs/img/cpu_info.png +0 -0
- msprobe/docs/img/free_benchmark.png +0 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/module_compare.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +0 -132
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +0 -59
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +0 -80
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +0 -330
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +0 -460
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +0 -2081
- msprobe/mindspore/code_mapping/bind.py +0 -283
- msprobe/mindspore/code_mapping/cmd_parser.py +0 -40
- msprobe/mindspore/code_mapping/graph.py +0 -49
- msprobe/mindspore/code_mapping/graph_parser.py +0 -211
- msprobe/mindspore/code_mapping/main.py +0 -24
- msprobe/mindspore/code_mapping/processor.py +0 -34
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +0 -111
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -52
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +0 -257
- msprobe/mindspore/free_benchmark/common/config.py +0 -27
- msprobe/mindspore/free_benchmark/common/handler_params.py +0 -31
- msprobe/mindspore/free_benchmark/common/utils.py +0 -100
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -638
- msprobe/mindspore/free_benchmark/handler/base_handler.py +0 -105
- msprobe/mindspore/free_benchmark/handler/check_handler.py +0 -55
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +0 -51
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +0 -36
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +0 -82
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +0 -45
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +0 -78
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +0 -77
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +0 -56
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +0 -27
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +0 -46
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +0 -51
- msprobe/mindspore/grad_probe/global_context.py +0 -127
- msprobe/mindspore/grad_probe/grad_analyzer.py +0 -260
- msprobe/mindspore/grad_probe/grad_monitor.py +0 -42
- msprobe/mindspore/grad_probe/grad_stat_csv.py +0 -161
- msprobe/mindspore/grad_probe/hook.py +0 -115
- msprobe/mindspore/grad_probe/utils.py +0 -43
- msprobe/mindspore/mindtorch/__init__.py +0 -18
- msprobe/mindspore/ms_config.py +0 -153
- msprobe/mindspore/task_handler_factory.py +0 -44
- msprobe/nan_analyze/__init__.py +0 -14
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +0 -9
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +0 -480
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +0 -567
- msprobe/pytorch/debugger/precision_debugger.py +0 -181
- msprobe/pytorch/free_benchmark/__init__.py +0 -23
- msprobe/pytorch/free_benchmark/common/constant.py +0 -85
- msprobe/pytorch/free_benchmark/common/counter.py +0 -87
- msprobe/pytorch/free_benchmark/common/enums.py +0 -80
- msprobe/pytorch/free_benchmark/common/params.py +0 -152
- msprobe/pytorch/free_benchmark/common/utils.py +0 -143
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -215
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +0 -121
- msprobe/pytorch/free_benchmark/main.py +0 -123
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +0 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +0 -56
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +0 -107
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +0 -121
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +0 -89
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +0 -87
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +0 -43
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +0 -60
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +0 -34
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +0 -252
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +0 -54
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +0 -40
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -45
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -181
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +0 -108
- msprobe/pytorch/grad_probe/grad_stat_csv.py +0 -160
- msprobe/pytorch/hook_module/__init__.py +0 -16
- msprobe/pytorch/hook_module/wrap_aten.py +0 -111
- msprobe/pytorch/online_dispatch/__init__.py +0 -19
- msprobe/pytorch/online_dispatch/compare.py +0 -224
- msprobe/pytorch/online_dispatch/dispatch.py +0 -332
- msprobe/pytorch/online_dispatch/dump_compare.py +0 -179
- msprobe/pytorch/online_dispatch/single_compare.py +0 -412
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +0 -58
- msprobe/pytorch/online_dispatch/utils.py +0 -158
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +0 -31
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +0 -253
- msprobe/pytorch/parse_tool/lib/config.py +0 -50
- msprobe/pytorch/parse_tool/lib/file_desc.py +0 -45
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +0 -97
- msprobe/pytorch/parse_tool/lib/parse_exception.py +0 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +0 -161
- msprobe/pytorch/parse_tool/lib/utils.py +0 -299
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -85
- msprobe/pytorch/pt_config.py +0 -299
- /msprobe/core/{grad_probe → dump}/__init__.py +0 -0
- /msprobe/{mindspore/code_mapping → core/dump/api_dump}/__init__.py +0 -0
- /msprobe/{mindspore/debugger → core/dump/data_dump}/__init__.py +0 -0
- /msprobe/{mindspore/exception_dump → core/dump/data_dump/data_processor}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark → core/dump/debugger}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark/common → core/dump/kernel_dump}/__init__.py +0 -0
- /msprobe/mindspore/{free_benchmark/handler → dump/debugger}/__init__.py +0 -0
- /msprobe/mindspore/{grad_probe → dump/dump_processor}/__init__.py +0 -0
- /msprobe/mindspore/{overflow_check → dump/exception_dump}/__init__.py +0 -0
- /msprobe/mindspore/{mindtorch → dump/mindtorch}/mindtorch_adaptor.py +0 -0
- /msprobe/{pytorch/api_accuracy_checker/run_ut → mindspore/dump/overflow_check}/__init__.py +0 -0
- /msprobe/{pytorch/debugger → mindspore/monitor}/__init__.py +0 -0
- /msprobe/{pytorch/free_benchmark/common → msaccucmp}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/.keep +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers → api_accuracy_checker/acc_check}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/torch_ut_setting.json +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers/npu → dump/api_dump}/__init__.py +0 -0
- /msprobe/pytorch/{hook_module → dump/api_dump}/support_wrap_ops.yaml +0 -0
- /msprobe/pytorch/{free_benchmark/result_handlers → dump/debugger}/__init__.py +0 -0
|
@@ -1,107 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
19
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
21
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
22
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
23
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
24
|
-
NpuBaseLayer,
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class AddNoiseLayer(NpuBaseLayer):
|
|
29
|
-
|
|
30
|
-
@recursion_depth_decorator("FreeBenchmark: AddNoiseLayer.add_noise")
|
|
31
|
-
def add_noise(self, tensor_obj):
|
|
32
|
-
if isinstance(tensor_obj, torch.Tensor):
|
|
33
|
-
self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
|
|
34
|
-
tensor_obj.dtype
|
|
35
|
-
)
|
|
36
|
-
if not self.pre_check(tensor_obj):
|
|
37
|
-
return tensor_obj
|
|
38
|
-
noise = self._get_noise(tensor_obj)
|
|
39
|
-
result = TorchC.where(
|
|
40
|
-
TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5),
|
|
41
|
-
TorchC.add(noise, tensor_obj),
|
|
42
|
-
tensor_obj,
|
|
43
|
-
).to(tensor_obj.dtype)
|
|
44
|
-
self.is_added = True
|
|
45
|
-
return result
|
|
46
|
-
if isinstance(tensor_obj, dict):
|
|
47
|
-
return {key: self.add_noise(value) for key, value in tensor_obj.items()}
|
|
48
|
-
if isinstance(tensor_obj, (tuple, list)):
|
|
49
|
-
return type(tensor_obj)([self.add_noise(value) for value in tensor_obj])
|
|
50
|
-
return tensor_obj
|
|
51
|
-
|
|
52
|
-
def handle(self, params: DataParams):
|
|
53
|
-
"""
|
|
54
|
-
对输入添加扰动并返回
|
|
55
|
-
"""
|
|
56
|
-
logger.info_on_rank_0(
|
|
57
|
-
f"[msprobe] Free benchmark: Perturbation is "
|
|
58
|
-
f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
|
|
59
|
-
)
|
|
60
|
-
params.perturbed_value = self.add_noise(params.args[params.valid_input_index])
|
|
61
|
-
return self.perturbed_result(params)
|
|
62
|
-
|
|
63
|
-
def _get_noise(self, tensor_obj):
|
|
64
|
-
dtype = tensor_obj.dtype
|
|
65
|
-
device = str(tensor_obj.device)
|
|
66
|
-
noise = TorchC.full(
|
|
67
|
-
tensor_obj.shape,
|
|
68
|
-
self.perturbed_value,
|
|
69
|
-
device=device,
|
|
70
|
-
dtype=dtype,
|
|
71
|
-
)
|
|
72
|
-
return noise
|
|
73
|
-
|
|
74
|
-
def _check_details(self, tensor_obj):
|
|
75
|
-
"""
|
|
76
|
-
判断是否需要添加扰动
|
|
77
|
-
"""
|
|
78
|
-
if not self.perturbed_value:
|
|
79
|
-
logger.warning_on_rank_0(
|
|
80
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
81
|
-
f"dtype unsupported. Cancel perturbation."
|
|
82
|
-
)
|
|
83
|
-
return False
|
|
84
|
-
if tensor_obj.numel() == 0:
|
|
85
|
-
logger.warning_on_rank_0(
|
|
86
|
-
f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
|
|
87
|
-
f" Cancel adding noise."
|
|
88
|
-
)
|
|
89
|
-
return False
|
|
90
|
-
abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
|
|
91
|
-
tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
|
|
92
|
-
)
|
|
93
|
-
try:
|
|
94
|
-
max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
|
|
95
|
-
except Exception:
|
|
96
|
-
logger.warning_on_rank_0(
|
|
97
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
98
|
-
f"when calculating the maximum value, the tensor is changed to float32."
|
|
99
|
-
)
|
|
100
|
-
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
101
|
-
if max_val < abs_tol:
|
|
102
|
-
logger.warning_on_rank_0(
|
|
103
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
104
|
-
f"maximum value is less than the minimum threshold. Cancel adding noise."
|
|
105
|
-
)
|
|
106
|
-
return False
|
|
107
|
-
return True
|
|
@@ -1,121 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
19
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
21
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
22
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
23
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
24
|
-
NpuBaseLayer,
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class BitNoiseLayer(NpuBaseLayer):
|
|
29
|
-
def __init__(self, api_name):
|
|
30
|
-
super().__init__(api_name)
|
|
31
|
-
self.bit_mode = TorchC.bitwise_xor
|
|
32
|
-
self.bit_tail: int = 1
|
|
33
|
-
self.bit_type = None
|
|
34
|
-
|
|
35
|
-
@recursion_depth_decorator("FreeBenchmark: BitNoiseLayer.add_bit_noise")
|
|
36
|
-
def add_bit_noise(self, tensor_obj):
|
|
37
|
-
"""
|
|
38
|
-
对输入添加噪声
|
|
39
|
-
"""
|
|
40
|
-
# finfo应该列入黑名单
|
|
41
|
-
|
|
42
|
-
if isinstance(tensor_obj, torch.Tensor):
|
|
43
|
-
self._set_perturbation_bit(tensor_obj)
|
|
44
|
-
if not self.pre_check(tensor_obj):
|
|
45
|
-
return tensor_obj
|
|
46
|
-
sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
|
|
47
|
-
noise = TorchC.full(
|
|
48
|
-
tensor_obj.shape,
|
|
49
|
-
self.bit_tail,
|
|
50
|
-
device=tensor_obj.device,
|
|
51
|
-
dtype=self.bit_type,
|
|
52
|
-
)
|
|
53
|
-
result = tensor_obj.view(self.bit_type)
|
|
54
|
-
result = TorchC.where(
|
|
55
|
-
TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
|
|
56
|
-
self.bit_mode(result, noise),
|
|
57
|
-
result,
|
|
58
|
-
).view(tensor_obj.dtype)
|
|
59
|
-
|
|
60
|
-
self.is_added = True
|
|
61
|
-
return result
|
|
62
|
-
if isinstance(tensor_obj, dict):
|
|
63
|
-
return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
|
|
64
|
-
if isinstance(tensor_obj, (tuple, list)):
|
|
65
|
-
return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
|
|
66
|
-
return tensor_obj
|
|
67
|
-
|
|
68
|
-
def handle(self, params: DataParams):
|
|
69
|
-
"""
|
|
70
|
-
对输入添加扰动并返回
|
|
71
|
-
"""
|
|
72
|
-
logger.info_on_rank_0(
|
|
73
|
-
f"[msprobe] Free benchmark: Perturbation is "
|
|
74
|
-
f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
|
|
75
|
-
)
|
|
76
|
-
params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
|
|
77
|
-
return self.perturbed_result(params)
|
|
78
|
-
|
|
79
|
-
def _check_details(self, tensor_obj):
|
|
80
|
-
"""
|
|
81
|
-
判断是否需要添加扰动, bit翻转
|
|
82
|
-
"""
|
|
83
|
-
if not self.bit_type:
|
|
84
|
-
logger.warning_on_rank_0(
|
|
85
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
86
|
-
f"dtype unsupported. Cancel perturbation."
|
|
87
|
-
)
|
|
88
|
-
return False
|
|
89
|
-
if tensor_obj.numel() == 0:
|
|
90
|
-
logger.warning_on_rank_0(
|
|
91
|
-
f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
|
|
92
|
-
f" Cancel adding noise."
|
|
93
|
-
)
|
|
94
|
-
return False
|
|
95
|
-
abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
|
|
96
|
-
tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
|
|
97
|
-
)
|
|
98
|
-
try:
|
|
99
|
-
max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
|
|
100
|
-
except Exception:
|
|
101
|
-
logger.warning_on_rank_0(
|
|
102
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
103
|
-
f"when calculate the maximum value, the tensor is changed to float32."
|
|
104
|
-
)
|
|
105
|
-
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
106
|
-
if max_val < abs_tol:
|
|
107
|
-
logger.warning_on_rank_0(
|
|
108
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
109
|
-
f"maximum value is less than the minimum threshold. Cancel adding noise."
|
|
110
|
-
)
|
|
111
|
-
return False
|
|
112
|
-
return True
|
|
113
|
-
|
|
114
|
-
def _set_perturbation_bit(self, tensor_obj):
|
|
115
|
-
"""
|
|
116
|
-
根据不同浮点数确定不同位数扰动值
|
|
117
|
-
"""
|
|
118
|
-
bit_len_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
|
|
119
|
-
if bit_len_type:
|
|
120
|
-
self.bit_tail = 1
|
|
121
|
-
self.bit_type = bit_len_type
|
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
19
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
21
|
-
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
22
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
23
|
-
NpuBaseLayer,
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class ChangeValueLayer(NpuBaseLayer):
|
|
28
|
-
def __init__(self, api_name):
|
|
29
|
-
super().__init__(api_name)
|
|
30
|
-
self.head: int = 0
|
|
31
|
-
self.tail: int = -1
|
|
32
|
-
|
|
33
|
-
@recursion_depth_decorator("FreeBenchmark: ChangeValueLayer.change_value")
|
|
34
|
-
def change_value(self, tensor_obj):
|
|
35
|
-
"""
|
|
36
|
-
交换张量首尾
|
|
37
|
-
"""
|
|
38
|
-
if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
|
|
39
|
-
new_tensor = TorchC.clone(tensor_obj)
|
|
40
|
-
if new_tensor.ndim == 1:
|
|
41
|
-
temp_first = TorchC.clone(new_tensor[self.head])
|
|
42
|
-
temp_last = TorchC.clone(new_tensor[self.tail])
|
|
43
|
-
new_tensor[self.head] = temp_last
|
|
44
|
-
new_tensor[self.tail] = temp_first
|
|
45
|
-
else:
|
|
46
|
-
temp_first = TorchC.clone(new_tensor[self.head][self.head])
|
|
47
|
-
temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
|
|
48
|
-
new_tensor[self.head][self.head] = temp_last
|
|
49
|
-
new_tensor[self.tail][self.tail] = temp_first
|
|
50
|
-
|
|
51
|
-
self.is_added = True
|
|
52
|
-
return new_tensor
|
|
53
|
-
if isinstance(tensor_obj, dict):
|
|
54
|
-
return {key: self.change_value(value) for key, value in tensor_obj.items()}
|
|
55
|
-
if isinstance(tensor_obj, (tuple, list)):
|
|
56
|
-
return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
|
|
57
|
-
return tensor_obj
|
|
58
|
-
|
|
59
|
-
def handle(self, params: DataParams):
|
|
60
|
-
"""
|
|
61
|
-
对输入添加扰动并返回
|
|
62
|
-
"""
|
|
63
|
-
logger.info_on_rank_0(
|
|
64
|
-
f"[msprobe] Free benchmark: Perturbation is "
|
|
65
|
-
f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
|
|
66
|
-
)
|
|
67
|
-
params.perturbed_value = self.change_value(params.args[params.valid_input_index])
|
|
68
|
-
return self.perturbed_result(params)
|
|
69
|
-
|
|
70
|
-
def _check_details(self, tensor_obj):
|
|
71
|
-
"""
|
|
72
|
-
判断是否需要添加扰动, 首尾值交换
|
|
73
|
-
"""
|
|
74
|
-
# 对于维度大于1的张量、要求1维至少大于1且0维和1维至少一个长度大于2
|
|
75
|
-
if tensor_obj.ndim > 1:
|
|
76
|
-
if tensor_obj.size(1) == 0 or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2):
|
|
77
|
-
logger.info_on_rank_0(
|
|
78
|
-
f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
|
|
79
|
-
f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
|
|
80
|
-
)
|
|
81
|
-
return False
|
|
82
|
-
# 不支持维度等于0的张量、对于维度等于1的张量、要求0维长度大于2
|
|
83
|
-
elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
|
|
84
|
-
logger.info_on_rank_0(
|
|
85
|
-
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
86
|
-
f"0-dimension must greater than 1. Cancel change value."
|
|
87
|
-
)
|
|
88
|
-
return False
|
|
89
|
-
return True
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.core.common.const import Const
|
|
18
|
-
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
19
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
21
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
22
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
23
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
24
|
-
NpuBaseLayer,
|
|
25
|
-
)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class ImprovePrecisionLayer(NpuBaseLayer):
|
|
29
|
-
|
|
30
|
-
@recursion_depth_decorator(
|
|
31
|
-
"FreeBenchmark: ImprovePrecisionLayer.improve_tensor_precision"
|
|
32
|
-
)
|
|
33
|
-
def improve_tensor_precision(self, tensor_obj):
|
|
34
|
-
if (
|
|
35
|
-
isinstance(tensor_obj, torch.Tensor)
|
|
36
|
-
and torch.is_floating_point(tensor_obj)
|
|
37
|
-
and tensor_obj.dtype not in [torch.float32, torch.float64]
|
|
38
|
-
):
|
|
39
|
-
self._set_improve_values(tensor_obj)
|
|
40
|
-
tensor_obj = self._change_dtype(tensor_obj)
|
|
41
|
-
self.is_added = True
|
|
42
|
-
return tensor_obj
|
|
43
|
-
if isinstance(tensor_obj, dict):
|
|
44
|
-
return {
|
|
45
|
-
key: self.improve_tensor_precision(value)
|
|
46
|
-
for key, value in tensor_obj.items()
|
|
47
|
-
}
|
|
48
|
-
if isinstance(tensor_obj, (tuple, list)):
|
|
49
|
-
return type(tensor_obj)(
|
|
50
|
-
[self.improve_tensor_precision(value) for value in tensor_obj]
|
|
51
|
-
)
|
|
52
|
-
return tensor_obj
|
|
53
|
-
|
|
54
|
-
def handle(self, params: DataParams):
|
|
55
|
-
logger.info_on_rank_0(
|
|
56
|
-
f"[msprobe] Free benchmark: Perturbation is "
|
|
57
|
-
f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
|
|
58
|
-
)
|
|
59
|
-
new_args = self.improve_tensor_precision(params.args)
|
|
60
|
-
if params.fuzz_stage == Const.BACKWARD:
|
|
61
|
-
new_kwargs = {}
|
|
62
|
-
else:
|
|
63
|
-
new_kwargs = self.improve_tensor_precision(params.kwargs)
|
|
64
|
-
# 如果输入中全为高精度、应跳过二次执行、减少多余显存引用
|
|
65
|
-
if not self.is_added:
|
|
66
|
-
return params.perturbed_result
|
|
67
|
-
if "inplace" in new_kwargs:
|
|
68
|
-
new_kwargs["inplace"] = False
|
|
69
|
-
params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
|
|
70
|
-
return params.perturbed_result
|
|
71
|
-
|
|
72
|
-
def _set_improve_values(self, inputs):
|
|
73
|
-
if inputs.dtype in [torch.float16, torch.bfloat16]:
|
|
74
|
-
self.perturbed_value = torch.float32
|
|
75
|
-
|
|
76
|
-
def _change_dtype(self, inputs):
|
|
77
|
-
if hasattr(inputs, CommonField.DEVICE):
|
|
78
|
-
device = inputs.device
|
|
79
|
-
if device is CommonField.META:
|
|
80
|
-
new_inputs = inputs.to(
|
|
81
|
-
device=CommonField.META, dtype=self.perturbed_value
|
|
82
|
-
)
|
|
83
|
-
else:
|
|
84
|
-
new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
|
|
85
|
-
else:
|
|
86
|
-
new_inputs = inputs.to(dtype=self.perturbed_value)
|
|
87
|
-
return new_inputs
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
18
|
-
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
19
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
20
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
|
|
21
|
-
NpuBaseLayer,
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class NoChangeLayer(NpuBaseLayer):
|
|
26
|
-
|
|
27
|
-
def no_change(self, tensor_obj):
|
|
28
|
-
"""
|
|
29
|
-
不对输入做任何改变、直接二次执行
|
|
30
|
-
"""
|
|
31
|
-
self.is_added = True
|
|
32
|
-
return tensor_obj
|
|
33
|
-
|
|
34
|
-
def handle(self, params: DataParams):
|
|
35
|
-
"""
|
|
36
|
-
对输入添加扰动并返回
|
|
37
|
-
"""
|
|
38
|
-
logger.info_on_rank_0(
|
|
39
|
-
f"[msprobe] Free benchmark: Perturbation is "
|
|
40
|
-
f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
|
|
41
|
-
)
|
|
42
|
-
params.perturbed_value = self.no_change(params.args[params.valid_input_index])
|
|
43
|
-
return self.perturbed_result(params)
|
|
@@ -1,60 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
from abc import abstractmethod
|
|
17
|
-
from typing import Any
|
|
18
|
-
|
|
19
|
-
import torch
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
21
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class NpuBaseLayer(BaseLayer):
|
|
25
|
-
def __init__(self, api_name: str) -> None:
|
|
26
|
-
super().__init__(api_name)
|
|
27
|
-
self.perturbed_value = None # 扰动的元素
|
|
28
|
-
self.is_added = False # 标记当前算子输入是否调整
|
|
29
|
-
|
|
30
|
-
@staticmethod
|
|
31
|
-
def perturbed_result(params: DataParams) -> Any:
|
|
32
|
-
args_front = params.args[: params.valid_input_index]
|
|
33
|
-
args_rear = params.args[params.valid_input_index + 1:]
|
|
34
|
-
# 此处会将有inplace属性的算子换为非inplace
|
|
35
|
-
if "inplace" in params.kwargs:
|
|
36
|
-
params.kwargs["inplace"] = False
|
|
37
|
-
params.perturbed_result = params.origin_func(
|
|
38
|
-
*args_front, params.perturbed_value, *args_rear, **params.kwargs
|
|
39
|
-
)
|
|
40
|
-
return params.perturbed_result
|
|
41
|
-
|
|
42
|
-
@abstractmethod
|
|
43
|
-
def handle(self, params: DataParams) -> Any:
|
|
44
|
-
pass
|
|
45
|
-
|
|
46
|
-
def pre_check(self, tensor_obj):
|
|
47
|
-
"""
|
|
48
|
-
检查张量是否符合标准(float类型且最大值大于对应精度最小值)
|
|
49
|
-
"""
|
|
50
|
-
# 只针对第一个满足要求的添加扰动
|
|
51
|
-
if self.is_added:
|
|
52
|
-
return False
|
|
53
|
-
if not torch.is_floating_point(tensor_obj):
|
|
54
|
-
return False
|
|
55
|
-
if not self._check_details(tensor_obj):
|
|
56
|
-
return False
|
|
57
|
-
return True
|
|
58
|
-
|
|
59
|
-
def _check_details(self, tensor_obj):
|
|
60
|
-
return True
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
18
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
19
|
-
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
20
|
-
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
21
|
-
from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class CpuLayer(BaseLayer):
|
|
25
|
-
|
|
26
|
-
def handle(self, params: DataParams):
|
|
27
|
-
|
|
28
|
-
logger.info_on_rank_0(
|
|
29
|
-
f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
|
|
30
|
-
)
|
|
31
|
-
new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
|
|
32
|
-
new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
|
|
33
|
-
params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
|
|
34
|
-
return params.perturbed_result
|