mindstudio-probe 8.3.3__py3-none-any.whl → 26.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.3.dist-info → mindstudio_probe-26.0.0a1.dist-info}/METADATA +26 -14
- mindstudio_probe-26.0.0a1.dist-info/RECORD +498 -0
- {mindstudio_probe-8.3.3.dist-info → mindstudio_probe-26.0.0a1.dist-info}/WHEEL +1 -1
- mindstudio_probe-26.0.0a1.dist-info/entry_points.txt +5 -0
- mindstudio_probe-26.0.0a1.dist-info/licenses/LICENSE +124 -0
- mindstudio_probe-26.0.0a1.dist-info/top_level.txt +2 -0
- msprobe/__init__.py +12 -13
- msprobe/config.json +9 -31
- msprobe/core/__init__.py +12 -11
- msprobe/core/acc_check/acc_check_cli.py +145 -0
- msprobe/core/common/const.py +97 -38
- msprobe/core/common/db_manager.py +133 -12
- msprobe/core/common/decorator.py +12 -11
- msprobe/core/common/exceptions.py +12 -11
- msprobe/core/common/file_utils.py +101 -25
- msprobe/core/common/framework_adapter.py +36 -25
- msprobe/core/common/global_lock.py +12 -11
- msprobe/core/common/inplace_op_checker.py +12 -11
- msprobe/core/common/log.py +22 -11
- msprobe/core/common/megatron_utils.py +566 -11
- msprobe/core/common/parallel_state.py +12 -11
- msprobe/core/common/runtime.py +12 -11
- msprobe/core/common/utils.py +41 -41
- msprobe/core/compare/acc_compare.py +361 -104
- msprobe/core/compare/atb_data_compare.py +422 -0
- msprobe/core/compare/auto_compare.py +134 -0
- msprobe/core/compare/check.py +14 -17
- msprobe/core/compare/compare_cli.py +72 -149
- msprobe/core/compare/config.py +12 -13
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +28 -15
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/analyzer.py +18 -18
- msprobe/core/compare/find_first/graph.py +12 -11
- msprobe/core/compare/find_first/utils.py +13 -12
- msprobe/core/compare/indicator_analysis/__init__.py +15 -0
- msprobe/core/compare/indicator_analysis/algorithm.py +363 -0
- msprobe/core/compare/indicator_analysis/api_data.py +141 -0
- msprobe/core/compare/indicator_analysis/calculator.py +181 -0
- msprobe/core/compare/indicator_analysis/utils.py +116 -0
- msprobe/core/compare/layer_mapping/__init__.py +12 -11
- msprobe/core/compare/layer_mapping/data_scope_parser.py +20 -11
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -13
- msprobe/core/compare/layer_mapping/postprocess_pass.py +13 -11
- msprobe/core/compare/merge_result/merge_result.py +12 -11
- msprobe/core/compare/merge_result/merge_result_cli.py +12 -11
- msprobe/core/compare/merge_result/utils.py +12 -11
- msprobe/core/compare/multiprocessing_compute.py +13 -14
- msprobe/core/compare/npy_compare.py +13 -11
- msprobe/core/compare/offline_data_compare.py +160 -0
- msprobe/core/compare/stats_diff_calc.py +39 -0
- msprobe/core/compare/torchair_acc_cmp.py +764 -0
- msprobe/core/compare/torchair_cmp_utils.py +338 -0
- msprobe/core/compare/utils.py +140 -49
- msprobe/core/config_check/__init__.py +12 -11
- msprobe/core/config_check/checkers/__init__.py +12 -11
- msprobe/core/config_check/checkers/base_checker.py +15 -14
- msprobe/core/config_check/checkers/dataset_checker.py +13 -12
- msprobe/core/config_check/checkers/env_args_checker.py +13 -12
- msprobe/core/config_check/checkers/hyperparameter_checker.py +16 -15
- msprobe/core/config_check/checkers/pip_checker.py +15 -15
- msprobe/core/config_check/checkers/random_checker.py +13 -12
- msprobe/core/config_check/checkers/weights_checker.py +14 -12
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +13 -17
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +13 -12
- msprobe/core/config_check/ckpt_compare/metrics.py +12 -11
- msprobe/core/config_check/config_check_cli.py +18 -17
- msprobe/core/config_check/config_checker.py +16 -14
- msprobe/core/config_check/resource/dependency.yaml +15 -12
- msprobe/core/config_check/resource/env.yaml +12 -11
- msprobe/core/config_check/utils/hyperparameter_parser.py +12 -11
- msprobe/core/config_check/utils/utils.py +12 -11
- msprobe/core/{data_dump → dump/api_dump}/api_registry.py +12 -11
- msprobe/core/{common_config.py → dump/common_config.py} +13 -24
- msprobe/core/dump/data_dump/data_collector.py +257 -0
- msprobe/core/{data_dump → dump/data_dump}/data_processor/base.py +45 -36
- msprobe/core/{data_dump → dump/data_dump}/data_processor/factory.py +33 -25
- msprobe/core/{data_dump → dump/data_dump}/data_processor/mindspore_processor.py +37 -113
- msprobe/core/{data_dump → dump/data_dump}/data_processor/pytorch_processor.py +364 -131
- msprobe/core/{data_dump → dump/data_dump}/json_writer.py +24 -31
- msprobe/core/{data_dump → dump/data_dump}/scope.py +12 -13
- msprobe/core/{debugger → dump/debugger}/precision_debugger.py +15 -23
- msprobe/core/dump/dump2db/db_utils.py +215 -0
- msprobe/core/dump/dump2db/dump2db.py +409 -0
- msprobe/core/{hook_manager.py → dump/hook_manager.py} +38 -87
- msprobe/core/dump/kernel_dump/kernel_config.py +34 -0
- msprobe/core/{service.py → dump/service.py} +43 -27
- msprobe/core/install_deps/install_deps.py +51 -0
- msprobe/core/monitor/anomaly_processor.py +13 -11
- msprobe/core/monitor/csv2db.py +73 -93
- msprobe/core/monitor/db_utils.py +140 -205
- msprobe/core/monitor/utils.py +18 -17
- msprobe/core/monitor_v2/__init__.py +20 -0
- msprobe/core/monitor_v2/base.py +83 -0
- msprobe/core/monitor_v2/cc.py +287 -0
- msprobe/core/monitor_v2/factory.py +81 -0
- msprobe/core/monitor_v2/module.py +201 -0
- msprobe/core/monitor_v2/optimizer.py +245 -0
- msprobe/core/monitor_v2/param.py +154 -0
- msprobe/core/monitor_v2/trainer.py +326 -0
- msprobe/core/monitor_v2/utils.py +122 -0
- msprobe/core/monitor_v2/weight_grad.py +419 -0
- msprobe/core/monitor_v2/writer.py +162 -0
- msprobe/core/overflow_check/abnormal_scene.py +12 -11
- msprobe/core/overflow_check/api_info.py +12 -11
- msprobe/core/overflow_check/checker.py +12 -11
- msprobe/core/overflow_check/filter.py +13 -11
- msprobe/core/overflow_check/level.py +12 -11
- msprobe/core/overflow_check/utils.py +12 -11
- msprobe/core/single_save/single_comparator.py +12 -11
- msprobe/core/single_save/single_saver.py +12 -11
- msprobe/infer/__init__.py +16 -0
- msprobe/infer/offline/__init__.py +16 -0
- msprobe/infer/offline/compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/adapter_cli/args_adapter.py +46 -0
- msprobe/infer/offline/compare/msquickcmp/atc/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/atc/atc_utils.py +98 -0
- msprobe/infer/offline/compare/msquickcmp/cmp_process.py +328 -0
- msprobe/infer/offline/compare/msquickcmp/common/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/common/args_check.py +112 -0
- msprobe/infer/offline/compare/msquickcmp/common/convert.py +74 -0
- msprobe/infer/offline/compare/msquickcmp/common/dump_data.py +121 -0
- msprobe/infer/offline/compare/msquickcmp/common/dynamic_argument_bean.py +39 -0
- msprobe/infer/offline/compare/msquickcmp/common/utils.py +669 -0
- msprobe/infer/offline/compare/msquickcmp/config.ini +6 -0
- msprobe/infer/offline/compare/msquickcmp/dump/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/dump/args_adapter.py +50 -0
- msprobe/infer/offline/compare/msquickcmp/dump/dump_process.py +91 -0
- msprobe/infer/offline/compare/msquickcmp/install_aclruntime_aisbench.sh +180 -0
- msprobe/infer/offline/compare/msquickcmp/main.py +199 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/net_compare/net_compare.py +277 -0
- msprobe/infer/offline/compare/msquickcmp/npu/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/npu/npu_dump_data.py +558 -0
- msprobe/infer/offline/compare/msquickcmp/npu/om_parser.py +416 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/__init__.py +16 -0
- msprobe/infer/offline/compare/msquickcmp/onnx_model/onnx_dump_data.py +374 -0
- msprobe/infer/utils/__init__.py +15 -0
- msprobe/infer/utils/acc_cmp.py +94 -0
- msprobe/infer/utils/check/__init__.py +37 -0
- msprobe/infer/utils/check/args_checker.py +35 -0
- msprobe/infer/utils/check/checker.py +227 -0
- msprobe/infer/utils/check/dict_checker.py +78 -0
- msprobe/infer/utils/check/func_wrapper.py +96 -0
- msprobe/infer/utils/check/list_checker.py +56 -0
- msprobe/infer/utils/check/number_checker.py +64 -0
- msprobe/infer/utils/check/obj_checker.py +41 -0
- msprobe/infer/utils/check/path_checker.py +249 -0
- msprobe/infer/utils/check/rule.py +126 -0
- msprobe/infer/utils/check/string_checker.py +66 -0
- msprobe/infer/utils/cmp_algorithm.py +261 -0
- msprobe/infer/utils/constants.py +112 -0
- msprobe/infer/utils/file_open_check.py +337 -0
- msprobe/infer/utils/util.py +177 -0
- msprobe/mindspore/__init__.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +14 -13
- msprobe/mindspore/api_accuracy_checker/api_info.py +12 -11
- msprobe/mindspore/api_accuracy_checker/api_runner.py +12 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +12 -11
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +12 -11
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +12 -11
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +15 -14
- msprobe/mindspore/api_accuracy_checker/compute_element.py +12 -11
- msprobe/mindspore/api_accuracy_checker/data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/main.py +12 -11
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +14 -12
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +13 -11
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +12 -11
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +12 -11
- msprobe/mindspore/api_accuracy_checker/utils.py +12 -11
- msprobe/mindspore/common/const.py +15 -74
- msprobe/mindspore/common/log.py +12 -11
- msprobe/mindspore/common/utils.py +30 -15
- msprobe/mindspore/compare/common_dir_compare.py +21 -23
- msprobe/mindspore/compare/distributed_compare.py +18 -16
- msprobe/mindspore/compare/ms_compare.py +14 -14
- msprobe/mindspore/compare/ms_graph_compare.py +26 -20
- msprobe/mindspore/compare/utils.py +14 -12
- msprobe/mindspore/{cell_processor.py → dump/cell_processor.py} +15 -14
- msprobe/mindspore/{debugger → dump/debugger}/debugger_config.py +12 -30
- msprobe/mindspore/{debugger → dump/debugger}/precision_debugger.py +43 -45
- msprobe/mindspore/dump/{cell_dump_process.py → dump_processor/cell_dump_process.py} +31 -17
- msprobe/mindspore/dump/{cell_dump_with_insert_gradient.py → dump_processor/cell_dump_with_insert_gradient.py} +18 -14
- msprobe/mindspore/dump/{dump_tool_factory.py → dump_processor/dump_tool_factory.py} +16 -15
- msprobe/mindspore/dump/{graph_mode_cell_dump.py → dump_processor/graph_mode_cell_dump.py} +16 -15
- msprobe/mindspore/dump/{graph_tensor_dump.py → dump_processor/graph_tensor_dump.py} +134 -133
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/api_register.py +15 -14
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/hook_cell.py +12 -11
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/ms_hook_manager.py +47 -20
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/primitive_hooks.py +14 -13
- msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/support_wrap_ops.yaml +13 -11
- msprobe/mindspore/dump/{jit_dump.py → dump_processor/jit_dump.py} +14 -13
- msprobe/mindspore/dump/{kernel_graph_dump.py → dump_processor/kernel_graph_dump.py} +13 -12
- msprobe/mindspore/dump/{kernel_kbyk_dump.py → dump_processor/kernel_kbyk_dump.py} +13 -12
- msprobe/mindspore/{exception_dump → dump/exception_dump}/exception_dump_tool_factory.py +14 -13
- msprobe/mindspore/{exception_dump → dump/exception_dump}/kernel_graph_exception_dump.py +13 -12
- msprobe/mindspore/{mindspore_service.py → dump/mindspore_service.py} +18 -17
- msprobe/mindspore/dump/mindtorch/__init__.py +19 -0
- msprobe/mindspore/dump/ms_config.py +105 -0
- msprobe/mindspore/{overflow_check → dump/overflow_check}/kernel_graph_overflow_check.py +13 -12
- msprobe/mindspore/{overflow_check → dump/overflow_check}/overflow_check_tool_factory.py +14 -13
- msprobe/mindspore/dump/task_handler_factory.py +43 -0
- msprobe/mindspore/monitor/common_func.py +12 -11
- msprobe/mindspore/monitor/data_writers.py +12 -11
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +93 -39
- msprobe/mindspore/monitor/features.py +12 -11
- msprobe/mindspore/monitor/module_hook.py +19 -22
- msprobe/mindspore/monitor/optimizer_collect.py +29 -25
- msprobe/mindspore/monitor/utils.py +13 -11
- msprobe/msaccucmp/advisor/__init__.py +16 -0
- msprobe/msaccucmp/advisor/advisor_const.py +65 -0
- msprobe/msaccucmp/advisor/advisor_result.py +73 -0
- msprobe/msaccucmp/advisor/compare_advisor.py +99 -0
- msprobe/msaccucmp/advisor/input_advisor.py +66 -0
- msprobe/msaccucmp/advisor/node_advisor.py +68 -0
- msprobe/msaccucmp/advisor/overflow_advisor.py +58 -0
- msprobe/msaccucmp/algorithm_manager/__init__.py +16 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_manager.py +464 -0
- msprobe/msaccucmp/algorithm_manager/algorithm_parameter.py +42 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_AccumulatedRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_CosineSimilarity.py +58 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_KullbackLeiblerDivergence.py +84 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanAbsoluteError.py +41 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanRelativeError.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RelativeEuclideanDistance.py +46 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RootMeanSquareError.py +40 -0
- msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_StandardDeviation.py +47 -0
- msprobe/msaccucmp/cmp_utils/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/common.py +113 -0
- msprobe/msaccucmp/cmp_utils/constant/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/constant/compare_error.py +81 -0
- msprobe/msaccucmp/cmp_utils/constant/const_manager.py +530 -0
- msprobe/msaccucmp/cmp_utils/file_utils.py +497 -0
- msprobe/msaccucmp/cmp_utils/log.py +257 -0
- msprobe/msaccucmp/cmp_utils/multi_process/__init__.py +16 -0
- msprobe/msaccucmp/cmp_utils/multi_process/multi_convert_process.py +140 -0
- msprobe/msaccucmp/cmp_utils/multi_process/progress.py +78 -0
- msprobe/msaccucmp/cmp_utils/path_check.py +274 -0
- msprobe/msaccucmp/cmp_utils/reg_manager.py +98 -0
- msprobe/msaccucmp/cmp_utils/tlv_parse.py +279 -0
- msprobe/msaccucmp/cmp_utils/utils.py +356 -0
- msprobe/msaccucmp/cmp_utils/utils_type.py +63 -0
- msprobe/msaccucmp/compare_vector.py +48 -0
- msprobe/msaccucmp/conversion/__init__.py +16 -0
- msprobe/msaccucmp/conversion/data_conversion.py +277 -0
- msprobe/msaccucmp/conversion/dtype_conversion.py +99 -0
- msprobe/msaccucmp/conversion/shape_format_conversion.py +477 -0
- msprobe/msaccucmp/conversion/tensor_conversion.py +369 -0
- msprobe/msaccucmp/dump_data_conversion.py +46 -0
- msprobe/msaccucmp/dump_parse/__init__.py +16 -0
- msprobe/msaccucmp/dump_parse/big_dump_data.py +317 -0
- msprobe/msaccucmp/dump_parse/dump.py +423 -0
- msprobe/msaccucmp/dump_parse/dump_data_object.py +322 -0
- msprobe/msaccucmp/dump_parse/dump_data_parser.py +436 -0
- msprobe/msaccucmp/dump_parse/dump_utils.py +246 -0
- msprobe/msaccucmp/dump_parse/ffts_parser.py +137 -0
- msprobe/msaccucmp/dump_parse/mapping.py +62 -0
- msprobe/msaccucmp/dump_parse/nano_dump_data.py +392 -0
- msprobe/msaccucmp/dump_parse/proto_dump_data.py +308 -0
- msprobe/msaccucmp/dump_parser.py +90 -0
- msprobe/msaccucmp/format_manager/__init__.py +16 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NCHW.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_ND.py +52 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NHWC.py +53 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_HWCN.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_NCHW.py +47 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_FRACTAL_Z.py +89 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_HWCN.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NCHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NHWC.py +43 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_NHWC.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_NCDHW.py +48 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_ND.py +44 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_FRACTAL_Z.py +87 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_HWCN.py +37 -0
- msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_NCHW.py +37 -0
- msprobe/msaccucmp/format_manager/format_manager.py +307 -0
- msprobe/msaccucmp/inplace_layer_process.py +186 -0
- msprobe/msaccucmp/msaccucmp.py +532 -0
- msprobe/msaccucmp/mscmp_advisor.py +128 -0
- msprobe/msaccucmp/overflow/__init__.py +16 -0
- msprobe/msaccucmp/overflow/overflow_analyse.py +305 -0
- msprobe/msaccucmp/overflow/overflow_detection.py +143 -0
- msprobe/msaccucmp/pytorch_cmp/__init__.py +16 -0
- msprobe/msaccucmp/pytorch_cmp/compare_pytorch.py +389 -0
- msprobe/msaccucmp/pytorch_cmp/hdf5_parser.py +377 -0
- msprobe/msaccucmp/pytorch_cmp/pytorch_dump_data.py +461 -0
- msprobe/msaccucmp/shape_conversion.py +41 -0
- msprobe/msaccucmp/vector_cmp/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/batch_compare.py +197 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/compare_detail.py +245 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail.py +182 -0
- msprobe/msaccucmp/vector_cmp/compare_detail/detail_writer.py +580 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_fusion_op.py +588 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_npu_vs_npu.py +339 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_result.py +326 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/compare_rule.py +156 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_op.py +204 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_rule_parser.py +635 -0
- msprobe/msaccucmp/vector_cmp/fusion_manager/quant_filter.py +187 -0
- msprobe/msaccucmp/vector_cmp/range_manager/__init__.py +16 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_manager.py +100 -0
- msprobe/msaccucmp/vector_cmp/range_manager/range_mode.py +94 -0
- msprobe/msaccucmp/vector_cmp/range_manager/select_mode.py +86 -0
- msprobe/msaccucmp/vector_cmp/vector_comparison.py +535 -0
- msprobe/msprobe.py +101 -130
- msprobe/overflow_check/__init__.py +15 -0
- msprobe/{nan_analyze → overflow_check}/analyzer.py +38 -27
- msprobe/{nan_analyze → overflow_check}/graph.py +28 -27
- msprobe/{nan_analyze → overflow_check}/utils.py +15 -14
- msprobe/pytorch/__init__.py +20 -14
- msprobe/pytorch/aclgraph_dump/__init__.py +45 -0
- msprobe/pytorch/aclgraph_dump/_meta.py +26 -0
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut.py → acc_check/acc_check.py} +50 -45
- msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut_utils.py → acc_check/acc_check_utils.py} +201 -30
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/data_generate.py +56 -16
- msprobe/pytorch/api_accuracy_checker/{run_ut/multi_run_ut.py → acc_check/multi_acc_check.py} +32 -47
- msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/run_overflow_check.py +19 -18
- msprobe/pytorch/api_accuracy_checker/common/config.py +22 -20
- msprobe/pytorch/api_accuracy_checker/common/utils.py +72 -13
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -11
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +23 -14
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +45 -32
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +12 -11
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +14 -12
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +14 -12
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +21 -19
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +14 -13
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +12 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +60 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +27 -16
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +13 -11
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +39 -18
- msprobe/pytorch/bench_functions/__init__.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam.py +12 -11
- msprobe/pytorch/bench_functions/apply_adam_w.py +12 -11
- msprobe/pytorch/bench_functions/confusion_transpose.py +12 -11
- msprobe/pytorch/bench_functions/fast_gelu.py +12 -11
- msprobe/pytorch/bench_functions/group_norm_silu.py +12 -11
- msprobe/pytorch/bench_functions/layer_norm_eval.py +12 -11
- msprobe/pytorch/bench_functions/linear.py +12 -11
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -11
- msprobe/pytorch/bench_functions/mish.py +12 -11
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +12 -11
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +12 -11
- msprobe/pytorch/bench_functions/rms_norm.py +12 -11
- msprobe/pytorch/bench_functions/rotary_mul.py +12 -11
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +12 -11
- msprobe/pytorch/bench_functions/sort_v2.py +12 -11
- msprobe/pytorch/bench_functions/swiglu.py +12 -11
- msprobe/pytorch/common/__init__.py +12 -11
- msprobe/pytorch/common/log.py +12 -11
- msprobe/pytorch/common/parse_json.py +12 -11
- msprobe/pytorch/common/utils.py +52 -19
- msprobe/pytorch/compare/distributed_compare.py +13 -13
- msprobe/pytorch/compare/match.py +12 -11
- msprobe/pytorch/compare/pt_compare.py +14 -20
- msprobe/pytorch/compare/pt_diff_analyze.py +12 -11
- msprobe/pytorch/compare/utils.py +12 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/api_register.py +18 -16
- msprobe/pytorch/{hook_module → dump/api_dump}/hook_module.py +14 -13
- msprobe/pytorch/{hook_module → dump/api_dump}/pt_hook_manager.py +68 -23
- msprobe/pytorch/{hook_module → dump/api_dump}/register_optimizer_hook.py +13 -11
- msprobe/pytorch/{hook_module → dump/api_dump}/script_wrapper.py +17 -14
- msprobe/pytorch/{hook_module → dump/api_dump}/utils.py +12 -11
- msprobe/pytorch/{debugger → dump/debugger}/debugger_config.py +23 -38
- msprobe/pytorch/dump/debugger/precision_debugger.py +130 -0
- msprobe/pytorch/{function_factory.py → dump/function_factory.py} +12 -11
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +17 -13
- msprobe/pytorch/dump/module_dump/module_dump.py +16 -15
- msprobe/pytorch/dump/module_dump/{module_processer.py → module_processor.py} +54 -42
- msprobe/pytorch/dump/pt_config.py +128 -0
- msprobe/pytorch/{pytorch_service.py → dump/pytorch_service.py} +22 -21
- msprobe/pytorch/monitor/csv2tb.py +13 -11
- msprobe/pytorch/monitor/data_writers.py +13 -11
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +13 -11
- msprobe/pytorch/monitor/features.py +12 -11
- msprobe/pytorch/monitor/module_hook.py +67 -59
- msprobe/pytorch/monitor/module_metric.py +13 -11
- msprobe/pytorch/monitor/optimizer_collect.py +37 -35
- msprobe/pytorch/monitor/utils.py +13 -11
- msprobe/pytorch/monitor/visualizer.py +12 -11
- msprobe/pytorch/torchair_dump/__init__.py +17 -0
- msprobe/pytorch/torchair_dump/torchair_dump.py +114 -0
- msprobe/scripts/atb/config_example.json +10 -0
- msprobe/scripts/atb/load_atb_probe.sh +101 -0
- msprobe/scripts/atb/unload_atb_probe.sh +27 -0
- msprobe/scripts/build_msaccucmp.sh +186 -0
- msprobe/scripts/conf/help.info +6 -0
- msprobe/scripts/conf/version.info +3 -0
- msprobe/scripts/run_script/common.sh +538 -0
- msprobe/scripts/run_script/main_msaccucmp.sh +232 -0
- msprobe/visualization/__init__.py +12 -11
- msprobe/visualization/builder/__init__.py +12 -11
- msprobe/visualization/builder/graph_builder.py +45 -30
- msprobe/visualization/builder/graph_merger.py +53 -32
- msprobe/visualization/builder/msprobe_adapter.py +34 -44
- msprobe/visualization/compare/__init__.py +12 -11
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +28 -113
- msprobe/visualization/db_utils.py +133 -22
- msprobe/visualization/graph/__init__.py +12 -11
- msprobe/visualization/graph/base_node.py +15 -27
- msprobe/visualization/graph/distributed_analyzer.py +97 -40
- msprobe/visualization/graph/graph.py +14 -16
- msprobe/visualization/graph/node_colors.py +34 -31
- msprobe/visualization/graph/node_op.py +12 -11
- msprobe/visualization/graph_service.py +580 -205
- msprobe/visualization/utils.py +278 -31
- tb_graph_ascend/secure_build.py +175 -0
- tb_graph_ascend/server/__init__.py +15 -0
- tb_graph_ascend/server/app/__init__.py +15 -0
- tb_graph_ascend/server/app/model/__init__.py +15 -0
- tb_graph_ascend/server/app/model/hierarchy.py +348 -0
- tb_graph_ascend/server/app/model/layout_hierarchy_model.py +69 -0
- tb_graph_ascend/server/app/model/match_nodes_model.py +573 -0
- tb_graph_ascend/server/app/repositories/__init__.py +15 -0
- tb_graph_ascend/server/app/repositories/graph_repo_base.py +32 -0
- tb_graph_ascend/server/app/repositories/graph_repo_db.py +879 -0
- tb_graph_ascend/server/app/repositories/graph_repo_vis.py +83 -0
- tb_graph_ascend/server/app/service/__init__.py +18 -0
- tb_graph_ascend/server/app/service/graph_service_base.py +158 -0
- tb_graph_ascend/server/app/service/graph_service_db.py +438 -0
- tb_graph_ascend/server/app/service/graph_service_factory.py +54 -0
- tb_graph_ascend/server/app/service/graph_service_vis.py +480 -0
- tb_graph_ascend/server/app/utils/__init__.py +15 -0
- tb_graph_ascend/server/app/utils/constant.py +80 -0
- tb_graph_ascend/server/app/utils/file_check_wrapper.py +46 -0
- tb_graph_ascend/server/app/utils/global_state.py +95 -0
- tb_graph_ascend/server/app/utils/graph_utils.py +661 -0
- tb_graph_ascend/server/app/utils/i18n.py +153 -0
- tb_graph_ascend/server/app/utils/request_method.py +46 -0
- tb_graph_ascend/server/app/views/__init__.py +15 -0
- tb_graph_ascend/server/app/views/graph_views.py +304 -0
- tb_graph_ascend/server/plugin.py +108 -0
- tb_graph_ascend/server/static/index.html +9250 -0
- tb_graph_ascend/server/static/index.js +21 -0
- tb_graph_ascend/setup.py +57 -0
- mindstudio_probe-8.3.3.dist-info/LICENSE +0 -201
- mindstudio_probe-8.3.3.dist-info/RECORD +0 -491
- mindstudio_probe-8.3.3.dist-info/entry_points.txt +0 -2
- mindstudio_probe-8.3.3.dist-info/top_level.txt +0 -1
- msprobe/CMakeLists.txt +0 -5
- msprobe/README.md +0 -203
- msprobe/core/advisor/advisor.py +0 -129
- msprobe/core/advisor/advisor_const.py +0 -58
- msprobe/core/advisor/advisor_result.py +0 -58
- msprobe/core/compare/find_first/data_processor.py +0 -35
- msprobe/core/compare/highlight.py +0 -390
- msprobe/core/data_dump/data_collector.py +0 -356
- msprobe/core/grad_probe/constant.py +0 -90
- msprobe/core/grad_probe/grad_compare.py +0 -187
- msprobe/core/grad_probe/utils.py +0 -105
- msprobe/core/kernel_dump/kernel_config.py +0 -33
- msprobe/docs/01.installation.md +0 -250
- msprobe/docs/02.config_introduction.md +0 -221
- msprobe/docs/03.config_examples.md +0 -281
- msprobe/docs/04.kernel_dump_PyTorch.md +0 -73
- msprobe/docs/05.data_dump_PyTorch.md +0 -518
- msprobe/docs/06.data_dump_MindSpore.md +0 -618
- msprobe/docs/07.accuracy_checker_PyTorch.md +0 -310
- msprobe/docs/09.accuracy_checker_MindSpore.md +0 -120
- msprobe/docs/10.accuracy_compare_PyTorch.md +0 -637
- msprobe/docs/11.accuracy_compare_MindSpore.md +0 -769
- msprobe/docs/12.overflow_check_PyTorch.md +0 -82
- msprobe/docs/13.overflow_check_MindSpore.md +0 -33
- msprobe/docs/14.data_parse_PyTorch.md +0 -282
- msprobe/docs/15.free_benchmarking_PyTorch.md +0 -169
- msprobe/docs/16.free_benchmarking_MindSpore.md +0 -159
- msprobe/docs/17.grad_probe.md +0 -205
- msprobe/docs/18.online_dispatch.md +0 -89
- msprobe/docs/19.monitor.md +0 -753
- msprobe/docs/20.monitor_performance_baseline.md +0 -52
- msprobe/docs/21.visualization_PyTorch.md +0 -519
- msprobe/docs/22.visualization_MindSpore.md +0 -515
- msprobe/docs/23.generate_operator_PyTorch.md +0 -107
- msprobe/docs/24.code_mapping_Mindspore.md +0 -29
- msprobe/docs/25.tool_function_introduction.md +0 -29
- msprobe/docs/26.data_dump_PyTorch_baseline.md +0 -48
- msprobe/docs/27.dump_json_instruction.md +0 -795
- msprobe/docs/28.debugger_save_instruction.md +0 -288
- msprobe/docs/28.kernel_dump_MindSpore.md +0 -69
- msprobe/docs/29.data_dump_MSAdapter.md +0 -235
- msprobe/docs/30.overflow_check_MSAdapter.md +0 -31
- msprobe/docs/31.config_check.md +0 -107
- msprobe/docs/32.ckpt_compare.md +0 -69
- msprobe/docs/33.generate_operator_MindSpore.md +0 -181
- msprobe/docs/34.RL_collect.md +0 -101
- msprobe/docs/35.nan_analyze.md +0 -73
- msprobe/docs/36.calculation_result_change.md +0 -75
- msprobe/docs/FAQ.md +0 -232
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +0 -146
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +0 -14
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +0 -33
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +0 -217
- msprobe/docs/img/BLOOM-7B_1.png +0 -0
- msprobe/docs/img/BLOOM-7B_2.png +0 -0
- msprobe/docs/img/BLOOM-7B_3.png +0 -0
- msprobe/docs/img/BLOOM-7B_4.png +0 -0
- msprobe/docs/img/GPT-3_1.png +0 -0
- msprobe/docs/img/GPT-3_2.png +0 -0
- msprobe/docs/img/GPT-3_3.png +0 -0
- msprobe/docs/img/GPT-3_4.png +0 -0
- msprobe/docs/img/GPT-3_5.png +0 -0
- msprobe/docs/img/GPT-3_6.png +0 -0
- msprobe/docs/img/GPT-3_7.png +0 -0
- msprobe/docs/img/GPT-3_8.png +0 -0
- msprobe/docs/img/YOLOV5S_1.png +0 -0
- msprobe/docs/img/YOLOV5S_2.png +0 -0
- msprobe/docs/img/accuracy_checking_details.png +0 -0
- msprobe/docs/img/accuracy_checking_result.png +0 -0
- msprobe/docs/img/api_precision_compare_details.png +0 -0
- msprobe/docs/img/api_precision_compare_result.png +0 -0
- msprobe/docs/img/auto_analyze_log.png +0 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/compare_result_pkl.png +0 -0
- msprobe/docs/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/docs/img/cpu_info.png +0 -0
- msprobe/docs/img/free_benchmark.png +0 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/module_compare.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +0 -132
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +0 -59
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +0 -80
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +0 -330
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +0 -460
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +0 -2081
- msprobe/mindspore/code_mapping/bind.py +0 -283
- msprobe/mindspore/code_mapping/cmd_parser.py +0 -40
- msprobe/mindspore/code_mapping/graph.py +0 -49
- msprobe/mindspore/code_mapping/graph_parser.py +0 -211
- msprobe/mindspore/code_mapping/main.py +0 -24
- msprobe/mindspore/code_mapping/processor.py +0 -34
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +0 -111
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -52
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +0 -257
- msprobe/mindspore/free_benchmark/common/config.py +0 -27
- msprobe/mindspore/free_benchmark/common/handler_params.py +0 -31
- msprobe/mindspore/free_benchmark/common/utils.py +0 -100
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -638
- msprobe/mindspore/free_benchmark/handler/base_handler.py +0 -105
- msprobe/mindspore/free_benchmark/handler/check_handler.py +0 -55
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +0 -51
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +0 -36
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +0 -82
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +0 -45
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +0 -78
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +0 -77
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +0 -56
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +0 -27
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +0 -46
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +0 -51
- msprobe/mindspore/grad_probe/global_context.py +0 -127
- msprobe/mindspore/grad_probe/grad_analyzer.py +0 -260
- msprobe/mindspore/grad_probe/grad_monitor.py +0 -42
- msprobe/mindspore/grad_probe/grad_stat_csv.py +0 -161
- msprobe/mindspore/grad_probe/hook.py +0 -115
- msprobe/mindspore/grad_probe/utils.py +0 -43
- msprobe/mindspore/mindtorch/__init__.py +0 -18
- msprobe/mindspore/ms_config.py +0 -153
- msprobe/mindspore/task_handler_factory.py +0 -44
- msprobe/nan_analyze/__init__.py +0 -14
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +0 -9
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +0 -480
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +0 -567
- msprobe/pytorch/debugger/precision_debugger.py +0 -181
- msprobe/pytorch/free_benchmark/__init__.py +0 -23
- msprobe/pytorch/free_benchmark/common/constant.py +0 -85
- msprobe/pytorch/free_benchmark/common/counter.py +0 -87
- msprobe/pytorch/free_benchmark/common/enums.py +0 -80
- msprobe/pytorch/free_benchmark/common/params.py +0 -152
- msprobe/pytorch/free_benchmark/common/utils.py +0 -143
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -215
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +0 -121
- msprobe/pytorch/free_benchmark/main.py +0 -123
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +0 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +0 -56
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +0 -107
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +0 -121
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +0 -89
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +0 -87
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +0 -43
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +0 -60
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +0 -34
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +0 -252
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +0 -54
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +0 -40
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -45
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -181
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +0 -108
- msprobe/pytorch/grad_probe/grad_stat_csv.py +0 -160
- msprobe/pytorch/hook_module/__init__.py +0 -16
- msprobe/pytorch/hook_module/wrap_aten.py +0 -111
- msprobe/pytorch/online_dispatch/__init__.py +0 -19
- msprobe/pytorch/online_dispatch/compare.py +0 -224
- msprobe/pytorch/online_dispatch/dispatch.py +0 -332
- msprobe/pytorch/online_dispatch/dump_compare.py +0 -179
- msprobe/pytorch/online_dispatch/single_compare.py +0 -412
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +0 -58
- msprobe/pytorch/online_dispatch/utils.py +0 -158
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +0 -31
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +0 -253
- msprobe/pytorch/parse_tool/lib/config.py +0 -50
- msprobe/pytorch/parse_tool/lib/file_desc.py +0 -45
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +0 -97
- msprobe/pytorch/parse_tool/lib/parse_exception.py +0 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +0 -161
- msprobe/pytorch/parse_tool/lib/utils.py +0 -299
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -85
- msprobe/pytorch/pt_config.py +0 -299
- /msprobe/core/{grad_probe → dump}/__init__.py +0 -0
- /msprobe/{mindspore/code_mapping → core/dump/api_dump}/__init__.py +0 -0
- /msprobe/{mindspore/debugger → core/dump/data_dump}/__init__.py +0 -0
- /msprobe/{mindspore/exception_dump → core/dump/data_dump/data_processor}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark → core/dump/debugger}/__init__.py +0 -0
- /msprobe/{mindspore/free_benchmark/common → core/dump/kernel_dump}/__init__.py +0 -0
- /msprobe/mindspore/{free_benchmark/handler → dump/debugger}/__init__.py +0 -0
- /msprobe/mindspore/{grad_probe → dump/dump_processor}/__init__.py +0 -0
- /msprobe/mindspore/{overflow_check → dump/exception_dump}/__init__.py +0 -0
- /msprobe/mindspore/{mindtorch → dump/mindtorch}/mindtorch_adaptor.py +0 -0
- /msprobe/{pytorch/api_accuracy_checker/run_ut → mindspore/dump/overflow_check}/__init__.py +0 -0
- /msprobe/{pytorch/debugger → mindspore/monitor}/__init__.py +0 -0
- /msprobe/{pytorch/free_benchmark/common → msaccucmp}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/.keep +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers → api_accuracy_checker/acc_check}/__init__.py +0 -0
- /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/torch_ut_setting.json +0 -0
- /msprobe/pytorch/{free_benchmark/perturbed_layers/npu → dump/api_dump}/__init__.py +0 -0
- /msprobe/pytorch/{hook_module → dump/api_dump}/support_wrap_ops.yaml +0 -0
- /msprobe/pytorch/{free_benchmark/result_handlers → dump/debugger}/__init__.py +0 -0
|
@@ -0,0 +1,764 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# This file is part of the MindStudio project.
|
|
3
|
+
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
|
|
4
|
+
#
|
|
5
|
+
# MindStudio is licensed under Mulan PSL v2.
|
|
6
|
+
# You can use this software according to the terms and conditions of the Mulan PSL v2.
|
|
7
|
+
# You may obtain a copy of Mulan PSL v2 at:
|
|
8
|
+
#
|
|
9
|
+
# http://license.coscl.org.cn/MulanPSL2
|
|
10
|
+
#
|
|
11
|
+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
|
12
|
+
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
13
|
+
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
14
|
+
# See the Mulan PSL v2 for more details.
|
|
15
|
+
# -------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from multiprocessing import Pool, cpu_count
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import FileCheckConst
|
|
25
|
+
from msprobe.core.common.file_utils import FileChecker, FileOpen, check_file_or_directory_path
|
|
26
|
+
from msprobe.core.common.log import logger
|
|
27
|
+
from msprobe.core.common.utils import CompareException
|
|
28
|
+
from msprobe.core.compare.torchair_cmp_utils import BasicDataInfo, fill_row_data, save_compare_result_to_csv
|
|
29
|
+
from msprobe.infer.utils.acc_cmp import parse_torchair_dump_data, set_msaccucmp_path_from_cann
|
|
30
|
+
|
|
31
|
+
# File name prefix that identifies a GE graph text file.
GE_GRAPH_FILE_PREFIX = "dynamo_original_graph_"
# Layout of a GE dump timestamp directory name; its length is what gets compared.
GE_DUMP_TIME_PATTERN = "YYYYMMDDHHMMSS"
FUSION_OP_TYPE = "AutomaticBufferFusionOp"
# File suffixes that are skipped when collecting GE dump files.
DUMP_FILE_FILTER_SUFIX = [".txt", ".npy", ".bin"]
# Directory names must be shorter than this to be treated as token ids.
MAX_TOKEN_LEN = 12
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_rank_id_from_torchair_data(dir_name: str):
    """Extract the rank id from a torchair dump directory name.

    A name qualifies when it starts with ``worldsize`` and the text following
    its last ``rank`` occurrence consists only of digits.

    Args:
        dir_name: Directory base name to inspect.

    Returns:
        The rank id as an int, or -1 when the name does not qualify.
    """
    if not dir_name.startswith("worldsize"):
        return -1
    _, marker, suffix = dir_name.rpartition("rank")
    # An empty marker means "rank" never occurs in the name.
    if not marker or not suffix.isdigit():
        return -1
    return int(suffix)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_torchair_ge_graph_path(my_path, rank=-1):
    """Collect GE graph text files below ``my_path``, ordered by the number in their names.

    A file qualifies when its name starts with ``GE_GRAPH_FILE_PREFIX``, ends
    with ".txt" and contains at least one run of digits; the first run is used
    as the sort key. When ``rank`` is greater than -1, only files whose name
    contains ``rank_{rank}_`` are considered.

    Args:
        my_path: Directory to search.
        rank: Rank id used to filter file names; -1 disables the filter.

    Returns:
        A list of full file paths sorted by the extracted number, or None when
        ``my_path`` is not a directory or no file qualifies.
    """
    if not os.path.isdir(my_path):
        return None

    max_depth = 5
    number_pattern = re.compile(r"(\d+)")
    rank_marker = f"rank_{rank}_" if rank > -1 else None
    ge_graph_files = []
    for cur_path, dir_names, file_names in os.walk(my_path):
        for file_name in file_names:
            if rank_marker is not None and rank_marker not in file_name:
                continue
            if not (file_name.startswith(GE_GRAPH_FILE_PREFIX) and file_name.endswith(".txt")):
                continue
            match = number_pattern.search(file_name)
            if match:
                ge_graph_files.append((os.path.join(cur_path, file_name), int(match.group(1))))

        # Depth relative to my_path, independent of a trailing separator on my_path.
        rel_path = os.path.relpath(cur_path, my_path)
        cur_depth = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1
        if cur_depth > max_depth:
            # Avoid going too deep: prune this branch only. Breaking out of the
            # walk here would also skip shallow siblings not visited yet.
            dir_names[:] = []

    if not ge_graph_files:
        return None
    return [path for path, _ in sorted(ge_graph_files, key=lambda item: item[1])]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_unique_key(cur_dict, cur_key):
    """Return a key based on ``cur_key`` that is not yet present in ``cur_dict``.

    Args:
        cur_dict: Container checked for existing keys.
        cur_key: Desired key.

    Returns:
        ``cur_key`` itself when unused, otherwise ``"{cur_key}#{n}"`` with the
        smallest positive ``n`` that is unused.
    """
    if cur_key not in cur_dict:
        return cur_key
    suffix = 1
    candidate = f"{cur_key}#{suffix}"
    while candidate in cur_dict:
        suffix += 1
        candidate = f"{cur_key}#{suffix}"
    return candidate
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def parse_pbtxt_to_dict(pbtxt_path):
    """Parse a protobuf-text style file into a list of nested dicts.

    The file is processed line by line:
      * a line containing " {" opens a nested dict keyed by the text before " {";
      * a line containing ": " stores a key/value pair (surrounding double
        quotes of the value are removed);
      * a line containing "}" closes the current nested dict.
    Repeated keys inside one dict are disambiguated with ``get_unique_key``.

    Args:
        pbtxt_path: Path of the text file to parse.

    Returns:
        A list with one dict per top-level block found in the file.
    """
    check_file_or_directory_path(pbtxt_path)
    with FileOpen(pbtxt_path, "r") as ff:
        contents = ff.read()

    result, cur_dict, superior_dicts, brackets_depth = [], {}, [], 0
    for cur_line in contents.split("\n"):
        cur_line = cur_line.strip()
        if not cur_line:
            continue

        if " {" in cur_line:
            if brackets_depth == 0:
                # A new top-level block starts: give it a fresh dict and parent stack.
                cur_dict = {}
                superior_dicts = []
                result.append(cur_dict)
            cur_key = get_unique_key(cur_dict, cur_line.split(" {")[0])
            cur_dict[cur_key] = {}
            # Remember the enclosing dict at this depth so the matching "}" can restore it.
            if len(superior_dicts) > brackets_depth:
                superior_dicts[brackets_depth] = cur_dict
            else:
                superior_dicts.append(cur_dict)
            cur_dict = cur_dict[cur_key]
            brackets_depth += 1
        elif ": " in cur_line:
            # Split on the first separator only: the value itself may contain ": ",
            # which previously made the two-name unpacking raise ValueError.
            cur_key, cur_value = cur_line.split(": ", 1)
            cur_key = get_unique_key(cur_dict, cur_key)
            if cur_value.startswith('"') and cur_value.endswith('"'):
                cur_value = cur_value[1:-1]
            cur_dict[cur_key] = cur_value
        elif "}" in cur_line:
            brackets_depth -= 1
            cur_dict = superior_dicts[brackets_depth]
    return result
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def judge_single_or_multi_device(path):
    """Tell whether the dump directory ``path`` looks like a multi-device dump.

    The dump counts as multi-device when ``path`` has more than one sub
    directory, or when its timestamp-like sub directory (all digits, same
    length as ``GE_DUMP_TIME_PATTERN``) has more than one sub directory.

    Args:
        path: Directory to inspect.

    Returns:
        True for a multi-device layout, otherwise False.
    """

    def _child_dirs(parent):
        # Names of the entries directly below parent that are directories.
        return [name for name in os.listdir(parent) if os.path.isdir(os.path.join(parent, name))]

    top_level_dirs = _child_dirs(path)
    if len(top_level_dirs) > 1:
        return True

    # At most one sub directory is left here; look inside it when it is a timestamp directory.
    expected_len = len(GE_DUMP_TIME_PATTERN)
    for name in top_level_dirs:
        if len(name) == expected_len and name.isdigit():
            return len(_child_dirs(os.path.join(path, name))) > 1

    return False
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _has_rank_directory(dump_dir):
    """Check whether ``dump_dir`` directly contains a rank-named sub directory.

    Args:
        dump_dir: Directory to inspect.

    Returns:
        True when at least one sub directory name yields a rank id other than
        -1 from ``get_rank_id_from_torchair_data``; False otherwise, including
        when ``dump_dir`` is not a directory.
    """
    if not os.path.isdir(dump_dir):
        return False
    return any(
        os.path.isdir(os.path.join(dump_dir, name)) and get_rank_id_from_torchair_data(name) != -1
        for name in os.listdir(dump_dir)
    )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _validate_read_path(path):
    """Run the common readable-path check on ``path``.

    The path is checked as a directory when it is one, otherwise as a file.

    Args:
        path: File or directory path to validate.
    """
    if os.path.isdir(path):
        path_type = FileCheckConst.DIR
    else:
        path_type = FileCheckConst.FILE
    checker = FileChecker(path, path_type, ability=FileCheckConst.READ_ABLE)
    checker.common_check()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def gather_data_with_token_id_fx(data_path, token_dirs, rank_info_existed=False):
    """Gather FX dump ``.npy`` files grouped by token id.

    The first directory level (in ``os.walk`` order) whose sub directories are
    all digit names shorter than ``MAX_TOKEN_LEN`` is taken as the token level.
    When no such level exists and ``token_dirs`` is empty, ``data_path`` itself
    is appended to ``token_dirs`` and used as the only token directory.

    Args:
        data_path: Root directory of the FX dump.
        token_dirs: Initial list of token directories; replaced when a token
            level is detected, appended to in the fallback case.
        rank_info_existed: When True the ``.npy`` files sit directly inside each
            token directory and a single mapping is returned. When False every
            token directory holds one sub directory per dump and one mapping is
            returned per dump.

    Returns:
        A list of dicts mapping token id to a list of ``.npy`` file paths. With
        ``rank_info_existed`` the token id is the directory number plus one
        (0 for non-numeric names); otherwise it is the directory number
        (0 for non-numeric names).
    """
    for cur_path, dirs, _ in os.walk(data_path):
        if not dirs:
            continue
        if all(len(name) < MAX_TOKEN_LEN and name.isdigit() for name in dirs):
            token_dirs = [os.path.join(cur_path, name) for name in sorted(dirs, key=int)]
            break

    if len(token_dirs) == 0:
        token_dirs.append(data_path)  # Just use data_path if found no token like dirs

    if rank_info_existed:
        gathered_files = {}
        for token_dir in token_dirs:
            base_name = os.path.basename(token_dir)
            cur_token_id = int(base_name) + 1 if base_name.isdigit() else 0
            gathered_files[cur_token_id] = [
                os.path.join(token_dir, f) for f in os.listdir(token_dir) if f.endswith(".npy")
            ]
        return [gathered_files]

    dump_dirs = {}
    for token_dir in token_dirs:
        base_name = os.path.basename(token_dir)
        cur_token_id = int(base_name) if base_name.isdigit() else 0
        dump_dirs[cur_token_id] = sorted(
            (
                os.path.join(token_dir, d)
                for d in os.listdir(token_dir)
                if os.path.isdir(os.path.join(token_dir, d))
            ),
            key=os.path.basename,
        )

    # Token 1 defines how many dumps exist. When it is absent (e.g. the data_path
    # fallback yields token 0 only) use the lowest token id instead of failing
    # with a TypeError on len(None).
    reference_token = 1 if 1 in dump_dirs else min(dump_dirs)
    num_dumps = len(dump_dirs[reference_token])

    gathered_files_list = []
    for dump_index in range(num_dumps):
        gathered_files = {}
        for cur_token_id, dumps in dump_dirs.items():
            if dump_index >= len(dumps):
                # A token with fewer dumps than the reference is skipped rather than raising IndexError.
                logger.warning(f"token {cur_token_id} has no dump directory at index {dump_index}, skipped.")
                continue
            dump_path = dumps[dump_index]
            gathered_files[cur_token_id] = [
                os.path.join(dump_path, f) for f in os.listdir(dump_path) if f.endswith(".npy")
            ]
        gathered_files_list.append(gathered_files)
    return gathered_files_list
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def gather_data_with_token_id(data_path, fx=False, rank_info_existed=False):
    """Gather dump files below ``data_path`` grouped by parent id and token id.

    FX dumps are delegated to ``gather_data_with_token_id_fx``. Otherwise every
    leaf directory (one without sub directories) is treated as a token
    directory, and all files below it are collected.

    Args:
        data_path: Root directory of the dump.
        fx: True when ``data_path`` holds FX dump data.
        rank_info_existed: Forwarded for FX data; for other data, True skips the
            multi-device detection.

    Returns:
        A list of dicts ordered by parent id; each dict maps token id to a list
        of file paths. Non-numeric directory names map to id 0.
    """
    if fx:
        return gather_data_with_token_id_fx(data_path, [], rank_info_existed)

    if rank_info_existed:
        is_multi_device = False
    else:
        is_multi_device = judge_single_or_multi_device(data_path)

    # Leaf directories, visited in path order, are regarded as the token level.
    walked = sorted(os.walk(data_path), key=lambda item: item[0])
    token_dirs = [cur_path for cur_path, dirs, _ in walked if not dirs]
    if not token_dirs:
        token_dirs.append(data_path)  # Just use data_path if found no token like dirs

    parent_dir_dict = {}
    for token_dir in token_dirs:
        if is_multi_device:
            # token_dir looks like:
            #   /home/dump/dump_20241114_113410/0/graph_1_0/1/4/
            #   dump path + timestamp + device_id + sub graph name + sub graph id + token_id
            # In the multi-device case the same sub graphs under each device_id
            # are compared, so walk up three levels to reach
            #   /home/dump/dump_20241114_113410/0
            device_dir = token_dir
            for _ in range(3):
                device_dir = os.path.dirname(device_dir)
            parent_dir = os.path.basename(device_dir)
        else:
            parent_dir = os.path.basename(os.path.dirname(token_dir))
        subdir = os.path.basename(token_dir)
        parent_id = int(parent_dir) if parent_dir.isdigit() else 0
        subdir_id = int(subdir) if subdir.isdigit() else 0
        collected = parent_dir_dict.setdefault(parent_id, {}).setdefault(subdir_id, [])
        for cur_path, _, file_names in os.walk(token_dir):
            collected.extend(os.path.join(cur_path, file_name) for file_name in file_names)

    return [parent_dir_dict[parent_id] for parent_id in sorted(parent_dir_dict)]
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def init_ge_dump_data_from_bin_path(ge_dump_path):
    """Index GE dump files by token id and op name.

    For data like:
      1/Add.Add_2.44.6.1706596912161941,
      1/Cast.Cast_9.19.6.1706596911887829,
      1/ConcatV2D.ConcatV2.42.6.1706596912161117,

    one entry of the returned list looks like:
      {1: {
        'Add_2': '1/Add.Add_2.44.6.1706596912161941',
        'Cast_9': '1/Cast.Cast_9.19.6.1706596911887829',
        'ConcatV2': '1/ConcatV2D.ConcatV2.42.6.1706596912161117',
      }}

    Files whose suffix is listed in ``DUMP_FILE_FILTER_SUFIX`` are ignored, and
    when two files map to the same op name the larger file is kept.

    Args:
        ge_dump_path: Root directory of the GE dump.

    Returns:
        A list of dicts mapping token id to ``{op name: file path}``.

    Raises:
        Exception: When no dump files could be gathered.
    """
    gathered_files_list = gather_data_with_token_id(ge_dump_path)
    if not gathered_files_list:
        raise Exception("Cannot get ge dump data, because the gathered_files_list is empty.")

    dump_data_with_token_id_list = []
    for gathered_files in gathered_files_list:
        dump_data_with_token_id = {}
        for token_id, file_list in gathered_files.items():
            op_to_file = {}
            for file_name in sorted(file_list):
                _, suffix = os.path.splitext(file_name)
                if suffix in DUMP_FILE_FILTER_SUFIX:
                    continue
                name_parts = os.path.basename(file_name).split(".")
                if len(name_parts) < 5:
                    logger.warning(f"invalid file name: {file_name}, should contain at least 4 '.'")
                    continue

                # Drop the leading op type and the three trailing fields.
                cur_op_name = ".".join(name_parts[1:-3])
                if cur_op_name not in op_to_file:
                    op_to_file[cur_op_name] = file_name
                    continue

                exists_file = op_to_file[cur_op_name]
                exists_file_size = os.path.getsize(exists_file)
                cur_file_size = os.path.getsize(file_name)
                keep_one = file_name if cur_file_size > exists_file_size else exists_file
                op_to_file[cur_op_name] = keep_one
                logger.warning(
                    f"duplicated op name: {cur_op_name}."
                    f" [{os.path.basename(file_name)}, {os.path.basename(exists_file)}]."
                    f" Will keep the larger one {os.path.basename(keep_one)}."
                )
            dump_data_with_token_id[token_id] = op_to_file
        dump_data_with_token_id_list.append(dump_data_with_token_id)
    return dump_data_with_token_id_list
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def init_fx_dump_data_from_path(fx_dump_path, rank_info_existed=False):
    """Index FX dump ``npy`` files by token id, op name and direction.

    For data like:
      1/mm-aten.mm.default.INPUT.0.20240125031118787351.npy,
      1/mm-aten.mm.default.INPUT.1.20240125031118787351.npy,
      1/mm-aten.mm.default.OUTPUT.0.20240125031118787351.npy,

    one entry of the returned list holds, per token id:
      {'mm-aten.mm.default': {
        'input': [
          '1/mm-aten.mm.default.INPUT.0.20240125031118787351.npy',
          '1/mm-aten.mm.default.INPUT.1.20240125031118787351.npy',
        ],
        'output': ['1/mm-aten.mm.default.OUTPUT.0.20240125031118787351.npy']
      }}

    The gathered token id is stored decremented by one; tokens without any
    ``npy`` file are left out.

    Args:
        fx_dump_path: Root directory of the FX dump.
        rank_info_existed: Forwarded to ``gather_data_with_token_id``.

    Returns:
        A list of dicts mapping token id to ``{op name: {"input"/"output": [paths]}}``.

    Raises:
        Exception: When no dump files could be gathered.
    """
    gathered_files_list = gather_data_with_token_id(fx_dump_path, fx=True, rank_info_existed=rank_info_existed)
    if not gathered_files_list:
        raise Exception("Cannot get fx dump data, because the gathered_files_list is empty.")

    dump_data_with_token_id_list = []
    for gathered_files in gathered_files_list:
        dump_data_with_token_id = {}
        for token_id, file_list in gathered_files.items():
            op_to_files = {}
            for file_path in sorted(file_list):
                if not file_path.endswith("npy"):
                    continue
                file_name = os.path.basename(file_path)
                if ".INPUT." in file_name:
                    separator, direction = ".INPUT.", "input"
                else:
                    separator, direction = ".OUTPUT.", "output"
                cur_op_name = file_name.split(separator)[0]
                op_to_files.setdefault(cur_op_name, {}).setdefault(direction, []).append(file_path)
            if op_to_files:
                # For FX data, token starts from 1, while GE is 0
                dump_data_with_token_id[token_id - 1] = op_to_files
        dump_data_with_token_id_list.append(dump_data_with_token_id)
    return dump_data_with_token_id_list
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def compare_single_data(golden_path, my_path, token_id=0, golden_data=None, my_data=None):
    """Compare one golden tensor with one of "my" tensors and return the row data.

    Args:
        golden_path: Path (or path-like description) of the golden data.
        my_path: Path (or path-like description) of the data under test.
        token_id: Token id recorded with the comparison.
        golden_data: Optional already loaded golden data.
        my_data: Optional already loaded data under test.

    Returns:
        Whatever ``fill_row_data`` returns for this pair.
    """
    return fill_row_data(
        BasicDataInfo(golden_path, my_path, token_id),
        loaded_my_data=my_data,
        loaded_golden_data=golden_data,
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Comparing GE with FX
|
|
345
|
+
def filter_valid_fx_desc_tensor_info(desc_key, desc_value):
    """Tell whether a descriptor entry carries an FX tensor name.

    A valid entry looks like:
      'attr': {'key': '_fx_tensor_name', 'value': {'s': 'add_1-aten.add.Tensor.OUTPUT.0'}}

    Args:
        desc_key: Entry key; must be "attr" or start with "attr#".
        desc_value: Entry value; must be a dict shaped as shown above.

    Returns:
        True when the entry is valid, otherwise False.
    """
    is_attr_key = desc_key == "attr" or desc_key.startswith("attr#")
    if not is_attr_key or not isinstance(desc_value, dict):
        return False
    if desc_value.get("key") != "_fx_tensor_name":
        return False
    tensor_value = desc_value.get("value")
    return isinstance(tensor_value, dict) and isinstance(tensor_value.get("s"), str)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_all_ops_from_fusion_op(op_name, graph_map_dict, ge_dump_data):
    """Split a (possibly fused) GE op name into the graph op names it is built from.

    The name is consumed from the left: ``find_longest_name`` picks the next
    piece, which is accepted only when it is a key of ``graph_map_dict``.
    Parsing stops at the first piece that cannot be resolved.

    Args:
        op_name: GE op name to split.
        graph_map_dict: Mapping of graph op name to op info.
        ge_dump_data: GE dump data passed through to ``find_longest_name``.

    Returns:
        The list of resolved op names, possibly empty.
    """
    all_ops = []
    remaining = op_name
    while remaining:
        cur_op_name = find_longest_name(remaining, graph_map_dict, ge_dump_data, ge_dump_data)
        if cur_op_name is not None and cur_op_name in graph_map_dict:
            all_ops.append(cur_op_name)
            remaining = remaining[len(cur_op_name):]
            continue
        logger.debug(f"Failed parsing ge op name: {cur_op_name}.Compare manually if required.")
        break
    return all_ops
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def compare_ge_with_fx(graph_map, ge_dump_data, fx_dump_data, token_id=0):
    """Compare every GE dump op with its FX counterpart and collect the row data.

    Each GE op name is resolved into graph ops with ``get_all_ops_from_fusion_op``:
      * one op: compared as a single op;
      * several ops: compared as a fused op (first op inputs, last op outputs);
      * none: only ops whose file type contains "Cast" or "TransData" are
        handled, by comparing their own inputs with their own outputs.

    Args:
        graph_map: Parsed graph, a list of dicts that may hold an "op" entry.
        ge_dump_data: Mapping of GE op name to dump file path.
        fx_dump_data: FX dump data of the same token.
        token_id: Token id recorded with every comparison.

    Returns:
        A list with the row data of all comparisons.
    """
    graph_map_dict = {}
    for graph in graph_map:
        if "op" in graph and "name" in graph["op"]:
            graph_map_dict[graph["op"]["name"]] = graph["op"]

    gathered_row_data = []
    ge_dump_data = sort_ge_dump_data(ge_dump_data, graph_map)
    for op_name, my_path in ge_dump_data.items():
        all_ops = get_all_ops_from_fusion_op(op_name, graph_map_dict, ge_dump_data)

        if len(all_ops) > 1:
            first_op_info = graph_map_dict.get(all_ops[0])
            last_op_info = graph_map_dict.get(all_ops[-1])
            gathered_row_data.extend(
                compare_ge_with_fx_multiple_ops(first_op_info, last_op_info, fx_dump_data, op_name, my_path, token_id)
            )
            continue

        if len(all_ops) == 1:
            op_info = graph_map_dict.get(all_ops[0])
            gathered_row_data.extend(compare_ge_with_fx_single_op(op_info, fx_dump_data, op_name, my_path, token_id))
            continue

        # Unresolved op: only Cast / TransData style ops are compared, input against output.
        op_type = os.path.basename(my_path).split(".")[0]
        if "Cast" not in op_type and "TransData" not in op_type:
            continue
        ge_inputs, ge_outputs = parse_torchair_dump_data(my_path)
        logger.debug(f"ge_inputs length: {len(ge_inputs)}")
        logger.debug(f"ge_outputs length:, {len(ge_outputs)}")
        gathered_row_data.extend(compare_specials_private_ops(ge_inputs, ge_outputs, token_id, my_path))

    return gathered_row_data
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def compare_ge_with_fx_single_op(op_info, fx_dump_data, op_name, my_path, token_id=0):
    """Compare one GE op with the FX op named in its output descriptors.

    Every "output_desc" entry of ``op_info`` is scanned for an FX tensor name.
    For each name found in ``fx_dump_data`` the GE dump is parsed and compared
    with the FX inputs and outputs; the result of the last such comparison is
    returned.

    Args:
        op_info: Graph op info dict.
        fx_dump_data: Mapping of FX op name to its "input"/"output" file lists.
        op_name: GE op name, used for logging.
        my_path: GE dump file path.
        token_id: Token id recorded with every comparison.

    Returns:
        The row data of the last comparison, or an empty list when none ran.
    """
    gathered_row_data = []
    for op_key, op_value in op_info.items():
        is_output_desc = op_key == "output_desc" or op_key.startswith("output_desc#")
        if not is_output_desc or not isinstance(op_value, dict):
            continue
        for out_key, out_value in op_value.items():
            if not filter_valid_fx_desc_tensor_info(out_key, out_value):
                continue
            fx_tensor_name = out_value.get("value", {}).get("s", None)
            name_parts = fx_tensor_name.split(".")
            if name_parts[-2] == "OUTPUT":
                # Strip the trailing ".OUTPUT.<index>" to get the FX op name.
                fx_tensor_name = ".".join(name_parts[:-2])
            if fx_tensor_name not in fx_dump_data:
                logger.warning(f"FX data missing, GE tensor name: {op_name}, FX tensor name: {fx_tensor_name}")
                continue

            ge_inputs, ge_outputs = parse_torchair_dump_data(my_path)
            fx_op_data = fx_dump_data.get(fx_tensor_name, {})
            fx_inputs = fx_op_data.get("input", [])
            fx_outputs = fx_op_data.get("output", [])
            logger.debug(f"ge_inputs length: {len(ge_inputs)}, fx_inputs length:, {len(fx_inputs)}")
            logger.debug(f"ge_outputs length: {len(ge_outputs)}, fx_outputs length:, {len(fx_outputs)}")
            gathered_row_data = compare_ops((fx_inputs, fx_outputs), (ge_inputs, ge_outputs), token_id, my_path)

    return gathered_row_data
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def compare_ge_with_fx_multiple_ops(first_op_info, last_op_info, fx_dump_data, op_name, my_path, token_id):
    """Compare a fused GE op with FX data: inputs via its first op, outputs via its last op.

    Args:
        first_op_info: Graph op info of the first op in the fusion.
        last_op_info: Graph op info of the last op in the fusion.
        fx_dump_data: FX dump data of the same token.
        op_name: GE op name, used for logging.
        my_path: GE dump file path.
        token_id: Token id recorded with every comparison.

    Returns:
        The input comparison rows followed by the output comparison rows.
    """
    input_rows = compare_ge_with_fx_multiple_ops_details(
        first_op_info, fx_dump_data, op_name, my_path, "input", token_id
    )
    output_rows = compare_ge_with_fx_multiple_ops_details(
        last_op_info, fx_dump_data, op_name, my_path, "output", token_id
    )
    return [*input_rows, *output_rows]
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def compare_ge_with_fx_multiple_ops_details(op_info: dict, *args):
    """Compare either the inputs or the outputs of a fused GE op with FX data.

    Args:
        op_info: Graph op info dict whose "output_desc" entries name the FX tensor.
        *args: Exactly five values, in order: ``fx_dump_data``, ``op_name``,
            ``my_path``, ``input_or_output`` ("input" or "output") and ``token_id``.

    Returns:
        A list with one row per compared FX/GE tensor pair.
    """
    fx_dump_data, op_name, my_path, input_or_output, token_id = args
    gathered_row_data = []
    for op_key, op_value in op_info.items():
        is_output_desc = op_key == "output_desc" or op_key.startswith("output_desc#")
        if not is_output_desc or not isinstance(op_value, dict):
            continue
        for out_key, out_value in op_value.items():
            if not filter_valid_fx_desc_tensor_info(out_key, out_value):
                continue
            fx_tensor_name = out_value.get("value", {}).get("s", None)
            name_parts = fx_tensor_name.split(".")
            if name_parts[-2] == "OUTPUT":
                # Strip the trailing ".OUTPUT.<index>" to get the FX op name.
                fx_tensor_name = ".".join(name_parts[:-2])
            if fx_tensor_name not in fx_dump_data:
                logger.warning(f"FX data missing, GE tensor name: {op_name}, FX tensor name: {fx_tensor_name}")
                continue

            ge_inputs, ge_outputs = parse_torchair_dump_data(my_path)
            fx_inputs_or_outputs = fx_dump_data.get(fx_tensor_name, {}).get(input_or_output, [])
            if input_or_output == "input":
                logger.debug(f"ge_inputs length: {len(ge_inputs)}, fx_inputs length:, {len(fx_inputs_or_outputs)}")
                ge_input_or_output_path, ge_inputs_or_outputs = "inputs", ge_inputs
            elif input_or_output == "output":
                logger.debug(f"ge_outputs length: {len(ge_outputs)}, fx_outputs length:, {len(fx_inputs_or_outputs)}")
                ge_input_or_output_path, ge_inputs_or_outputs = "outputs", ge_outputs
            else:
                # Any other direction compares nothing.
                ge_input_or_output_path, ge_inputs_or_outputs = "", []

            for cur_id, (fx_item, ge_item) in enumerate(zip(fx_inputs_or_outputs, ge_inputs_or_outputs)):
                cur_ge_data = f"{my_path},{ge_input_or_output_path},{cur_id}"
                gathered_row_data.append(compare_single_data(fx_item, cur_ge_data, token_id, my_data=ge_item))

    return gathered_row_data
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def compare_specials_private_ops(ge_inputs, ge_outputs, token_id, my_path):
    """Compare the inputs of one dumped op against its own outputs, index by index.

    Inputs and outputs are paired positionally with ``zip``, so surplus items on
    the longer side are ignored.

    Args:
        ge_inputs: Sequence of input data items; each is passed to
            ``compare_single_data`` as its fourth positional argument.
        ge_outputs: Sequence of output data items; each is passed as the fifth
            positional argument.
        token_id: Forwarded unchanged to ``compare_single_data``.
        my_path: Dump path used to build the ``"<my_path>,inputs,<idx>"`` and
            ``"<my_path>,outputs,<idx>"`` labels.

    Returns:
        list: One ``compare_single_data`` result per (input, output) pair.
    """
    return [
        compare_single_data(
            f"{my_path},inputs,{index}",
            f"{my_path},outputs,{index}",
            token_id,
            data_in,
            data_out,
        )
        for index, (data_in, data_out) in enumerate(zip(ge_inputs, ge_outputs))
    ]
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def compare_ops(fx_tuple, ge_tuple, token_id, my_path):
    """Compare FX data against GE data for one op, inputs first and then outputs.

    Args:
        fx_tuple: Indexable pair ``(fx_inputs, fx_outputs)``; each FX item is
            passed to ``compare_single_data`` as its first argument.
        ge_tuple: Indexable pair ``(ge_inputs, ge_outputs)``; each GE item is
            passed as the ``my_data`` keyword.
        token_id: Forwarded unchanged to ``compare_single_data``.
        my_path: GE dump path used to build the ``"<my_path>,inputs,<idx>"`` /
            ``"<my_path>,outputs,<idx>"`` labels.

    Returns:
        list: Results for all input pairs followed by results for all output
        pairs. Pairs are formed with ``zip``, so unmatched trailing items are
        ignored.
    """
    rows = []
    # Slot 0 of both tuples holds the inputs, slot 1 the outputs.
    for slot, section in enumerate(("inputs", "outputs")):
        for index, (fx_item, ge_item) in enumerate(zip(fx_tuple[slot], ge_tuple[slot])):
            ge_label = f"{my_path},{section},{index}"
            rows.append(compare_single_data(fx_item, ge_label, token_id, my_data=ge_item))
    return rows
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
# Comparing fused GE with GE
|
|
496
|
+
def get_all_op_input_names(op_info):
    """Collect the input names of an op, with the trailing ``:<suffix>`` removed.

    Only entries whose key is exactly ``"input"`` or starts with ``"input#"``
    are considered, in the mapping's iteration order.

    Args:
        op_info: Mapping describing one op; the selected values must be strings.

    Returns:
        list[str]: For each selected value, the text before its last ``":"``.
        A value containing no ``":"`` contributes an empty string.
    """
    names = []
    for key, value in op_info.items():
        if key != "input" and not key.startswith("input#"):
            continue
        # Keep everything in front of the last ":" (empty when there is none).
        names.append(value.rpartition(":")[0])
    return names
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def find_longest_name(op_name, op_map, fused_ge_dump_data, ge_dump_data):
    """Find the longest prefix of ``op_name`` that is a key of ``op_map``.

    The full name is tried first, then progressively shorter non-empty
    prefixes. The search is abandoned as soon as a proper prefix is found in
    either dump mapping without being in ``op_map``.

    Args:
        op_name: Name (typically a concatenation of op names) to match.
        op_map: Container of known op names.
        fused_ge_dump_data: Container of op names present in the fused dump.
        ge_dump_data: Container of op names present in the plain GE dump.

    Returns:
        The matching prefix, or ``None`` when nothing matches or the search
        was abandoned.
    """
    if op_name in op_map:
        return op_name

    # Walk the proper prefixes from longest (len - 1) down to a single character.
    for end in range(len(op_name) - 1, 0, -1):
        prefix = op_name[:end]
        if prefix in op_map:
            return prefix
        in_dump = prefix in fused_ge_dump_data or prefix in ge_dump_data
        if in_dump:
            # The prefix exists in dump data but not in op_map: give up.
            return None
    return None
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def gather_fused_op_data(fused_op_name, op_map, fused_ge_dump_data, ge_dump_data):
    """Split a fused op name into member ops and gather their golden data.

    The name is consumed from the left: ``find_longest_name`` picks the next
    member op, whose input names come from ``op_map`` and whose data comes from
    ``ge_dump_data``. Parsing stops with a warning at the first part that
    cannot be resolved. Inputs whose name equals one of the member ops are
    dropped from the result.

    Args:
        fused_op_name: Name of the fused op.
        op_map: Mapping of op name to op description.
        fused_ge_dump_data: Op names present in the fused dump (forwarded to
            ``find_longest_name``).
        ge_dump_data: Mapping of op name to golden dump path.

    Returns:
        tuple: ``((inputs, input_paths), (outputs, output_path))``. ``outputs``
        and ``output_path`` come from the last member op that had dump data
        (``[]`` and ``None`` if none did). Inputs of member ops without dump
        data are empty float32 arrays with an empty path.
    """
    input_names, input_data, input_paths, member_ops = [], [], [], []
    last_outputs, last_path = [], None
    remaining = fused_op_name

    while len(remaining) > 0:
        member = find_longest_name(remaining, op_map, fused_ge_dump_data, ge_dump_data)
        if member is None or member not in op_map:
            logger.warning(f"Failed parsing fused op name: {remaining}. Compare manually if required.")
            break
        names = get_all_op_input_names(op_map[member])

        if member not in ge_dump_data:
            logger.warning(
                f"No dump data for op: {member}. Seldom should this happen. Input data matching may be incorrect."
            )
            # Placeholders keep names, data and paths aligned one-to-one.
            placeholder = np.array([], dtype="float32")
            data = [placeholder] * len(names)
            paths = [""] * len(names)
        else:
            dump_path = ge_dump_data[member]
            data, last_outputs = parse_torchair_dump_data(dump_path)
            # Graph and dump may disagree on the input count; keep the common part.
            count = min(len(names), len(data))
            names, data = names[:count], data[:count]
            paths = [",".join([dump_path, "inputs", str(idx)]) for idx in range(count)]
            last_path = dump_path  # overwritten until the last member op

        input_names.extend(names)
        member_ops.append(member)
        input_data.extend(data)
        input_paths.extend(paths)
        remaining = remaining[len(member):]

    kept_data, kept_paths = [], []
    for name, item, path in zip(input_names, input_data, input_paths):
        if name in member_ops:
            continue
        kept_paths.append(path)
        kept_data.append(item)
    # The outputs are only those of the last member op.
    return (kept_data, kept_paths), (last_outputs, last_path)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def compare_ge_with_ge(graph_map, fused_ge_dump_data, ge_dump_data, token_id=0):
    """Compare every op of a (possibly fused) GE dump against a golden GE dump.

    For each op in ``fused_ge_dump_data`` (visited in the order produced by
    ``sort_ge_dump_data``):

    * if the dump file's base name starts with ``FUSION_OP_TYPE``, golden data
      is assembled from the member ops via ``gather_fused_op_data``;
    * otherwise, if the op name exists in ``ge_dump_data``, golden data is
      parsed from that path;
    * otherwise a warning is logged and the op is skipped.

    Args:
        graph_map: Iterable of graph entries; entries holding an ``"op"`` with
            a ``"name"`` are indexed by that name.
        fused_ge_dump_data: Mapping of op name to dump path of the data under
            test.
        ge_dump_data: Mapping of op name to golden dump path.
        token_id: Forwarded unchanged to ``compare_single_data``.

    Returns:
        list: ``compare_single_data`` results, inputs before outputs per op.
    """
    graph_map_dict = {ii["op"]["name"]: ii["op"] for ii in graph_map if "op" in ii and "name" in ii["op"]}
    fused_ge_dump_data = sort_ge_dump_data(fused_ge_dump_data, graph_map)
    gathered_row_data = []
    for op_name, my_path in fused_ge_dump_data.items():
        is_fused_op = os.path.basename(my_path).startswith(FUSION_OP_TYPE)
        if is_fused_op:
            (golden_inputs, golden_input_pathes), (golden_outputs, golden_output_path) = gather_fused_op_data(
                op_name, graph_map_dict, fused_ge_dump_data, ge_dump_data
            )
        elif op_name in ge_dump_data:
            golden_path = ge_dump_data[op_name]
            golden_inputs, golden_outputs = parse_torchair_dump_data(golden_path)
            golden_input_pathes = [golden_path] * len(golden_inputs)
            golden_output_path = golden_path
        else:
            logger.warning(f"Golden data missing, My tensor name: {op_name}")
            continue

        my_inputs, my_outputs = parse_torchair_dump_data(my_path)
        logger.debug(f"golden_inputs length: {len(golden_inputs)}, my_inputs length:, {len(my_inputs)}")
        logger.debug(f"golden_outputs length: {len(golden_outputs)}, my_outputs length:, {len(my_outputs)}")

        for cur_id, (golden_input, my_input, golden_input_path) in enumerate(
            zip(golden_inputs, my_inputs, golden_input_pathes)
        ):
            cur_ge_data = f"{my_path},inputs,{cur_id}"
            # Paths from gather_fused_op_data already end in ",inputs,<idx>" (or are
            # empty placeholders); plain golden paths still need the suffix. The
            # check must target the *input* path: touching golden_output_path here
            # corrupted the output labels and failed when it was None.
            if golden_input_path and ",inputs," not in golden_input_path:
                golden_input_path = f"{golden_input_path},inputs,{cur_id}"
            row_data = compare_single_data(
                golden_input_path, cur_ge_data, token_id, golden_data=golden_input, my_data=my_input
            )
            gathered_row_data.append(row_data)
        for cur_id, (golden_output, my_output) in enumerate(zip(golden_outputs, my_outputs)):
            cur_ge_data = f"{my_path},outputs,{cur_id}"
            # Build a per-output label instead of rebinding golden_output_path, so
            # suffixes do not pile up ("...,outputs,0,outputs,1").
            cur_golden_data = f"{golden_output_path},outputs,{cur_id}"
            row_data = compare_single_data(
                cur_golden_data, cur_ge_data, token_id, golden_data=golden_output, my_data=my_output
            )
            gathered_row_data.append(row_data)
    return gathered_row_data
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def sort_ge_dump_data(dump_data, graph_map):
    """Return ``dump_data`` reordered to follow the op order of ``graph_map``.

    Args:
        dump_data: Mapping of op name to dump data (or path).
        graph_map: Sequence of graph entries. An entry contributes its position
            only when it has an ``"op"`` containing a ``"name"``; other entries
            are ignored, matching how ``compare_ge_with_ge`` indexes the graph.

    Returns:
        OrderedDict: The items of ``dump_data`` sorted by the position of their
        op in ``graph_map``. Ops not found in the graph sort first (key ``-1``),
        and ties keep their original relative order because the sort is stable.
    """
    op_order = {
        node["op"]["name"]: position
        for position, node in enumerate(graph_map)
        if "op" in node and "name" in node["op"]
    }
    ordered_names = sorted(dump_data, key=lambda name: op_order.get(name, -1))
    return OrderedDict((op_name, dump_data[op_name]) for op_name in ordered_names)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def sort_by_timestamp(gathered_row_data):
    """Return the comparison rows sorted by the tail of their ``my_data_path``.

    ``gathered_row_data`` is the list holding the comparison results; every
    element of the list is a dict.

    1. Take the ``'my_data_path'`` field of each dict, e.g.
       ``OpType.OpName.12.7.1734070081497686,inputs,0``.
    2. Split it on ``'.'``:
       ``['OpType', 'OpName', '12', '7', '1734070081497686,inputs,0']``.
    3. The last element, ``'1734070081497686,inputs,0'``, contains the
       timestamp (``1734070081497686``) and the input/output information
       (``inputs,0``).
    4. Sort the comparison results by that element (a plain string comparison).
    5. Return the sorted list; the input list itself is not modified.
    """

    def _path_tail(row):
        # Text after the last "." (the whole path when it contains no ".").
        return row['my_data_path'].rpartition('.')[2]

    return sorted(gathered_row_data, key=_path_tail)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
# Main entry point
|
|
618
|
+
def acc_compare(golden_path, my_path, output_path='./', rank_id=None, rank_info_existed=False):
    """Run the torchair accuracy comparison for one or several ranks.

    Args:
        golden_path: Root directory of the golden dump.
        my_path: Root directory of the dump under test; it must contain a GE
            graph.
        output_path: Directory handed to ``acc_compare_once`` for the results.
        rank_id: Rank to compare. ``None`` means every rank that has a
            directory under both ``golden_path`` and ``my_path``. Only used
            when ``rank_info_existed`` is true.
        rank_info_existed: Whether the dump directories are organised per
            rank. When false, a single comparison with rank ``-1`` is run.

    Returns:
        ``output_path``, unchanged.

    Raises:
        Exception: If no GE graph is found under ``my_path``, or if no rank is
            common to both dumps. Errors raised by the comparison itself are
            propagated (wrapped by ``save_compare_once`` in the multi-rank
            case).
    """
    set_msaccucmp_path_from_cann()

    if not get_torchair_ge_graph_path(my_path):
        raise Exception("Can not get ge graph, Please check whether the input path contains graph.")

    if not rank_info_existed:
        compared_ranks = [-1]
    elif rank_id is not None:
        compared_ranks = [rank_id]
    else:
        compared_ranks = list(_collect_torchair_rank_ids(golden_path) & _collect_torchair_rank_ids(my_path))
        if not compared_ranks:
            raise Exception("No common rank data in golden_path and my_path.")

    args = [(golden_path, my_path, output_path, rid) for rid in compared_ranks]
    # If only a single rank needs comparison, run synchronously so that any exceptions
    # raised inside `acc_compare_once` are propagated directly.
    if len(args) == 1:
        acc_compare_once(args[0])
    else:
        # The context manager terminates the pool on exit, so worker processes are
        # not left behind when `get()` re-raises a worker failure before `join()`.
        with Pool(min(len(args), int(cpu_count() * 1.3))) as processes_pool:
            async_results = [processes_pool.apply_async(save_compare_once, (arg,)) for arg in args]
            processes_pool.close()
            # `get()` re-raises in this process any exception raised in a worker.
            for res in async_results:
                res.get()
            processes_pool.join()

    return output_path


def _collect_torchair_rank_ids(root_path):
    """Return the set of rank ids of the rank directories directly under ``root_path``."""
    rank_ids = set()
    for subdir in os.listdir(root_path):
        found_rank = get_rank_id_from_torchair_data(subdir)
        # -1 is the value the helper returns for names that are not rank directories.
        if found_rank != -1 and os.path.isdir(os.path.join(root_path, subdir)):
            rank_ids.add(found_rank)
    return rank_ids
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def save_compare_once(args):
    """Call ``acc_compare_once(args)`` and report any failure as ``ValueError``.

    This is the function submitted to the worker pool by ``acc_compare``.

    Args:
        args: Argument tuple passed through to ``acc_compare_once``.

    Returns:
        Whatever ``acc_compare_once`` returns.

    Raises:
        ValueError: For any ``Exception`` raised by ``acc_compare_once``; the
            message embeds the original text and the original exception is
            chained as ``__cause__``.
    """
    try:
        outcome = acc_compare_once(args)
    except Exception as exc:
        raise ValueError("Error in acc_compare_once: " + str(exc)) from exc
    return outcome
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def compare_torchair_mode(args):
    """
    Entry point used by the CLI to trigger torchair accuracy compare.

    Args:
        args: Parsed CLI arguments. Reads ``target_path``, ``golden_path`` and
            ``output_path``, plus the optional ``rank`` attribute.

    Returns:
        The value returned by ``acc_compare``.

    Raises:
        CompareException: With ``INVALID_PARAM_ERROR`` when ``rank`` is given
            but is not a single non-negative integer.
    """
    my_path = os.path.realpath(args.target_path)
    golden_path = os.path.realpath(args.golden_path)
    _validate_read_path(my_path)
    _validate_read_path(golden_path)

    rank_arg = getattr(args, "rank", None)
    rank_id = None
    if rank_arg is not None:
        rank_str = str(rank_arg).strip()
        # isdecimal() (not isdigit()) so that characters such as "²", which int()
        # cannot parse, are rejected here instead of escaping as a raw ValueError.
        if not rank_str.isdecimal():
            logger.error("Argument --rank only supports a single integer when mode=='torchair'.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR, "Invalid rank parameter for torchair mode.")
        rank_id = int(rank_str)

    rank_info_existed = _has_rank_directory(my_path)
    if not rank_info_existed and rank_id is not None:
        logger.warning('The directory structure of torchair data is old, the rank parameter will not take effect.')
    logger.info(f"[compare_torchair] start comparing, golden_path: {golden_path}, target_path: {my_path}")
    return acc_compare(golden_path, my_path, args.output_path, rank_id, rank_info_existed)
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def acc_compare_once(*args):
|
|
700
|
+
dir_of_golden_path, dir_of_my_path, output_path, rank_id = args[0]
|
|
701
|
+
if rank_id != -1:
|
|
702
|
+
subdirs = []
|
|
703
|
+
for subdir in os.listdir(dir_of_golden_path):
|
|
704
|
+
is_dir = os.path.isdir(os.path.join(dir_of_golden_path, subdir))
|
|
705
|
+
if is_dir and subdir.startswith('worldsize') and subdir.endswith(f'rank{rank_id}'):
|
|
706
|
+
subdirs.append(subdir)
|
|
707
|
+
if not subdirs:
|
|
708
|
+
raise Exception(f'Can not get golden data in rank {rank_id}')
|
|
709
|
+
golden_path = os.path.join(dir_of_golden_path, subdirs[-1])
|
|
710
|
+
|
|
711
|
+
subdirs = []
|
|
712
|
+
for subdir in os.listdir(dir_of_my_path):
|
|
713
|
+
is_dir = os.path.isdir(os.path.join(dir_of_my_path, subdir))
|
|
714
|
+
if is_dir and subdir.startswith('worldsize') and subdir.endswith(f'rank{rank_id}'):
|
|
715
|
+
subdirs.append(subdir)
|
|
716
|
+
if not subdirs:
|
|
717
|
+
raise Exception(f'Can not get my data in rank {rank_id}')
|
|
718
|
+
my_path = os.path.join(dir_of_my_path, subdirs[-1])
|
|
719
|
+
else:
|
|
720
|
+
golden_path = dir_of_golden_path
|
|
721
|
+
my_path = dir_of_my_path
|
|
722
|
+
|
|
723
|
+
ge_graph_path = get_torchair_ge_graph_path(dir_of_my_path, rank_id)
|
|
724
|
+
|
|
725
|
+
if not ge_graph_path:
|
|
726
|
+
raise Exception("Can not get ge graph, Please check whether the input path contains graph.")
|
|
727
|
+
|
|
728
|
+
logger.info(f"[compare_torchair], golden_path: {golden_path}, my_path: {my_path}, ge_graph_path: {ge_graph_path}")
|
|
729
|
+
|
|
730
|
+
graph_map_list = []
|
|
731
|
+
for path in ge_graph_path:
|
|
732
|
+
graph_map_list.append(parse_pbtxt_to_dict(path))
|
|
733
|
+
|
|
734
|
+
my_dump_data_list = init_ge_dump_data_from_bin_path(my_path)
|
|
735
|
+
|
|
736
|
+
is_golden_fx = get_torchair_ge_graph_path(dir_of_golden_path) is None
|
|
737
|
+
if is_golden_fx:
|
|
738
|
+
logger.info("Comparing GE with FX")
|
|
739
|
+
golden_dump_data_list = init_fx_dump_data_from_path(golden_path, rank_id != -1)
|
|
740
|
+
else:
|
|
741
|
+
logger.info("Comparing GE with GE")
|
|
742
|
+
golden_dump_data_list = init_ge_dump_data_from_bin_path(golden_path)
|
|
743
|
+
|
|
744
|
+
graph_map_list_len = len(graph_map_list)
|
|
745
|
+
for i in range(graph_map_list_len):
|
|
746
|
+
logger.info(f"All token ids in my_dump_data: {my_dump_data_list[i].keys()}")
|
|
747
|
+
logger.info(f"All token ids in golden_dump_data: {golden_dump_data_list[i].keys()}")
|
|
748
|
+
graph_map = graph_map_list[i]
|
|
749
|
+
my_dump_data = my_dump_data_list[i]
|
|
750
|
+
golden_dump_data = golden_dump_data_list[i]
|
|
751
|
+
|
|
752
|
+
gathered_row_data = []
|
|
753
|
+
for token_id in my_dump_data:
|
|
754
|
+
if token_id not in golden_dump_data:
|
|
755
|
+
logger.warning(f"My token_id {token_id} not found in golden dump data")
|
|
756
|
+
continue
|
|
757
|
+
logger.info(f"Comparing token_id: {token_id}")
|
|
758
|
+
if is_golden_fx:
|
|
759
|
+
row_data = compare_ge_with_fx(graph_map, my_dump_data[token_id], golden_dump_data[token_id], token_id)
|
|
760
|
+
else:
|
|
761
|
+
row_data = compare_ge_with_ge(graph_map, my_dump_data[token_id], golden_dump_data[token_id], token_id)
|
|
762
|
+
gathered_row_data.extend(row_data)
|
|
763
|
+
sorted_gathered_row_data = sort_by_timestamp(gathered_row_data)
|
|
764
|
+
save_compare_result_to_csv(sorted_gathered_row_data, output_path, rank_id=rank_id)
|