mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py}
RENAMED
@@ -15,24 +15,27 @@
 """Checkpoint related classes and functions."""
 
 import os
+from mindspore.utils import _tft_handler
 from mindspore.train.serialization import save_checkpoint
-from mindspore.parallel._utils import _get_device_num
-from mindspore import _checkparam as Validator
 from mindspore.train.callback._callback import Callback
-from mindspore import context
+from mindspore import context, ops
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.communication import get_rank, get_group_size
 from mindspore import log as logger
 from mindspore.train.serialization import _get_cur_rank_dp
 from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
+from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
 from mindspore._c_expression import clean_tdt_channel
 from mindspore._c_expression import send_recv, reset_params
 from mindspore._c_expression import CollectiveManager
 from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
-from mindspore._c_expression import Tensor as Tensor_
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.ops.operations.manually_defined._inner import TensorReport
 import mindspore
 import mindspore.common.dtype as mstype
+from mindspore.parallel._recovery_context import _set_recovery_context
+
 
 def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     """ Common func to generate ckpt dir name."""
@@ -40,30 +43,38 @@ def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
     mid_dir = f"tft_saved_checkpoints-step_{str(step)}{tmp}"
     return os.path.join(ckpt_save_path, mid_dir)
 
+
 def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
     """ Callback used for TFT save ckpt function when errors occur."""
     logger.info("Enter _save_checkpoint_on_failure function")
-    if not cb_ctx._is_params_consistent():
+    if not cb_ctx._is_params_consistent():  # pylint: disable=W0212
         raise RuntimeError("Can't save parameters, because they are left in inconsistent state!")
+    cb_params = args
+    # we record the current step and epoch num in on_train_step_end, so we can just reset it here
+    cb_params.cur_step_num = cb_ctx.cur_step_num
+    cb_params.cur_epoch_num = cb_ctx.cur_epoch_num
+    if cb_params.optimizer is not None:
+        cb_params.optimizer.global_step = cb_ctx.global_step
+    if hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+        cb_params.network.optimizer.global_step = cb_ctx.global_step
+    append_dict = {}
+    append_dict["__exception_save__"] = True
+    # if user has provided a custom save callback, use it
+    if cb_ctx.save_cb:
+        cb_ctx.save_cb(cb_params, append_dict)
+        logger.info("Finish _save_checkpoint_on_failure function")
+        return
 
+    # if user has not provided a custom save callback, use default save logic
     ckpt_save_path = cb_ctx.ckpt_save_path
-    cb_params = args
     cur_rank = get_rank()
-
+    step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
     cur_epoch_num = cb_params.cur_epoch_num
-    batch_num = cb_params.batch_num
-    if cur_step_num > step:
-        cur_epoch_num = (step - 1) // batch_num + 1
-        step_num_in_epoch = int((step - 1) % batch_num + 1)
-
-    append_dict = {}
     append_dict["epoch_num"] = cur_epoch_num
-    append_dict["step_num"] =
+    append_dict["step_num"] = cb_params.cur_step_num
     append_dict["cur_rank"] = cur_rank
-    append_dict["batch_num"] = batch_num
-    append_dict["
-
-    append_dict["global_step"] = Parameter([cb_ctx.global_step])
+    append_dict["batch_num"] = cb_params.batch_num
+    append_dict["global_step"] = cb_ctx.global_step
     outputs = cb_params.net_outputs
     if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
         append_dict["loss_scale"] = outputs[2]
@@ -76,49 +87,63 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
                     integrated_save=False, append_dict=append_dict)
     logger.info("Finish _save_checkpoint_on_failure function")
 
+
 def _rename_save_result(step, cb_ctx):
     """ Callback used for TFT rename function after ckpt save callback was finished and successful."""
     logger.info("Enter _rename_save_result function")
+    if cb_ctx.save_cb:
+        logger.info("User's save callback is provided, skip rename")
+        return
     tmp_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, True)
     fin_dir = _get_ckpt_dir(step, cb_ctx.ckpt_save_path, False)
 
     os.rename(tmp_dir, fin_dir)
     logger.info("Finish _rename_save_result function")
 
+
 def _tft_exit_cb(ctx):
+    """Callback used for TFT exit function."""
     logger.error("Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!")
     _tft_sem_post()
-    os._exit(1)
+    os._exit(1)  # pylint: disable=W0212
 
 
 def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
     """ Callback used for TFT repair function."""
-    logger.
-    if(repair_info["repair_type"]
-
-        logger.
+    logger.warning("Enter _tft_repair_callback repair type: {}".format(repair_info["repair_type"]))
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
+        logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
         _repair_device(cb_ctx.device_id)
 
-    if(repair_info["repair_type"]
-
-
-        {}, src_rank:{}, dst_rank: {}".format(
+    if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
+                                       cb_ctx.tft.RepairType.RT_SEND.value,
+                                       cb_ctx.tft.RepairType.RT_RECV_REPAIR.value)):
+        logger.warning("Enter _tft_repair_callback SEND_RECV repair type:{}, src_rank:{}, dst_rank: {}".format(
+            repair_info["repair_type"], repair_info["src"], repair_info["dst"]))
         cb_params = args
-
-
-
-
-
+        if repair_info["repair_type"] == cb_ctx.tft.RepairType.RT_SEND.value:
+            for i in range(len(repair_info["src"])):
+                src_rank = repair_info["src"][i]
+                dst_rank = repair_info["dst"][i]
+                if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                    raise ValueError("Call send_recv failed.")
+        else:
+            src_rank = repair_info["src"][0]
+            dst_rank = repair_info["dst"][0]
+            if send_recv(cb_params.train_network.trainable_params(), src_rank, dst_rank) != 0:
+                raise ValueError("Call send_recv failed.")
+    logger.warning("Finish _tft_repair_callback")
 
 
 def _tft_clean_callback(is_uce_error, args, ctx):
     """ Callback used for TFT clean function."""
-    logger.
+    logger.warning("Enter _tft_clean_callback")
     ret = 0
     if is_uce_error:
         _get_uce_mem_info(ctx.device_id)
         err_strategy = _get_uce_process_strategy()
-        logger.
+        logger.warning("_tft_clean_callback err_strategy: {}".format(err_strategy))
        if err_strategy == "RS_UCE_HIGHLEVEL":
            ret = 0
        elif err_strategy == "RS_UCE_LOWLEVEL":
@@ -126,37 +151,49 @@ def _tft_clean_callback(is_uce_error, args, ctx):
         else:
             ret = 1
     clean_tdt_channel()
-    logger.
+    logger.warning("Enter _tft_clean_callback resume_hccl_comm")
     CollectiveManager.get_instance().resume_hccl_comm()
-    logger.
+    logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
     return ret
 
 
 def _tft_stop_callback(args, cb_ctx):
     """ Callback used for TFT stop function."""
-    logger.
+    logger.warning("Enter _tft_stop_callback device_id: {}".format(cb_ctx.device_id))
     _stop_device(cb_ctx.device_id)
-    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):
+    if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()):  # pylint: disable=W0212
        raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
     cb_ctx.is_uce_rank = False
+    if cb_ctx.tft.tft_get_repair_type() == "recover":
+        logger.warning(f"Reset limit step")
+        cb_ctx.tft.tft_reset_limit_step()
     logger.info("Finish _tft_stop_callback")
 
 
-class TFTRegister(Callback):
+def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
+    """Callback used for TFT Rebuild Group function."""
+    logger.warning(f"Enter _tft_rebuild_sub_groups, device id: ".format(ctx.device_id))
+    _finalize_comm()
+    _rebuild_world_group()
+    _rebuild_sub_group()
+    _set_recovery_context(is_arf=True)
+    logger.warning("Enter _tft_rebuild_sub_groups ok ")
+
+
+class TrainFaultTolerance(Callback):
     """
     This callback is used to enable the TFT feature
-    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
-
+    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_
+    and will execute TFT operations during training process, such as TFT init, report and exception handle.
 
     Note:
         Required for Ascend graph mode only. And sink size must be less than or equal to 1.
 
     Args:
-
-
-
-
-            named ttp_saved_checkpoints-step_{cur_step_num} under this directory.
+        ckpt_save_path (str): Checkpoint save directory when failure occurs. When saved,
+            a new directory named 'ttp_saved_checkpoints-step_{cur_step_num}'
+            is created in that directory. Default: ``None``.
+        kwargs (dict): Other dictionary type parameters.
 
     Raises:
         Exception: TFT init failed.
@@ -168,7 +205,7 @@ class TFTRegister(Callback):
 
         It's recommended to use the msrun startup method.
         Please see the `msrun start up
-        <https://www.mindspore.cn/
+        <https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
        for more details.
 
        This example should be run with 4 devices.
@@ -181,7 +218,7 @@ class TFTRegister(Callback):
         >>> from mindspore import nn, ops, Parameter, train
         >>> from mindspore.communication import init, get_rank
         >>> from mindspore.common.initializer import initializer, HeUniform
-        >>> from mindspore.train import Model, TFTRegister
+        >>> from mindspore.train import Model, TrainFaultTolerance
         >>> from mindspore import dataset as ds
         >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
         >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
@@ -252,43 +289,52 @@ class TFTRegister(Callback):
         >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
         >>> loss_fn = nn.CrossEntropyLoss()
         >>>
-        >>> net_with_loss = nn.
+        >>> net_with_loss = nn.Pipeline(nn.WithLossCell(net, loss_fn), 4)
         >>> net_with_loss.set_train()
         >>> model = Model(net_with_loss, optimizer=optimizer_wrapper)
-        >>> tft_cb =
+        >>> tft_cb = TrainFaultTolerance()
         >>> loss_cb = train.LossMonitor(1)
         >>> model.train(1, dataset, callbacks=[tft_cb, loss_cb])
         """
 
-    def __init__(self,
-        super(
-
-
-        if
-            raise ValueError("
-
-
-
-
-
-        # let it raise errors if not install mindio_tft package
-        from mindio_ttp import framework_ttp as tft
-        self.tft = tft
-        self.global_step = 0
-        Validator.check_non_negative_int(ctrl_port)
+    def __init__(self, ckpt_save_path=None, **kwargs):
+        super(TrainFaultTolerance, self).__init__()
+        self.save_cb = kwargs.get("ckpt_save_fn", None)
+        self.ckpt_save_path = ckpt_save_path
+        if self.save_cb is None and self.ckpt_save_path is None:
+            raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
+        self.tft = _tft_handler.get_tft()
+        self._check_init()
+        self.global_step = None
+        self.learning_rate = None
         self.has_init_replica = False
         self.is_uce_rank = False
-        self._controller_ip = ctrl_ip
-        self._controller_rank_id = ctrl_rank_id
-        self._controller_port = ctrl_port
         self.cb_params = None
+        self.initial_step = kwargs.get("initial_step", 0)
         self.device_id = context.get_context("device_id")
-        self._init_tft()
-        self.ckpt_save_path = ckpt_save_path
         self.assign = mindspore.ops.Assign()
         self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
         self.s1 = mindspore.hal.Stream()
+        self.cur_step_num = 0
+        self.cur_epoch_num = 0
         _tft_sem_enable()
+        self._tft_register()
+
+    def _check_init(self):
+        """Check if the mindio-ttp had inited"""
+        if self.tft is None:
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            if "ARF:1" in tft_env:
+                raise ValueError("Must init by _tft_handler.init(config=params) if use ARF.")
+            logger.warning(f"TFT handle not init, try to init")
+            _tft_handler.init(config=None)
+            self.tft = _tft_handler.get_tft()
+            logger.warning(f"TFT handle init ok.")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError(f"MindIO adataper only support on Ascend device with GRAPH Mode!"
+                             f"device:{device_target}, run mode: {mode}")
 
     def _is_params_consistent(self):
        for key, param in self.cb_params.train_network.parameters_and_names():
@@ -300,7 +346,7 @@ class TFTRegister(Callback):
             return False
 
     def _set_tft_optimizer_replica(self, run_context):
-        """
+        """ Set Mindio TFT optimizer replica info, used internal. """
         cur_rank = get_rank()
         cb_params = run_context.original_args()
         train_network = cb_params.train_network
@@ -322,33 +368,49 @@ class TFTRegister(Callback):
         ]
         self.tft.tft_set_optimizer_replica(cur_rank, replica_info)
 
-
-
-
+    @classmethod
+    def get_optimizer_wrapper(cls, origin_opt_cls):
+        """
+        Optimizer wrapper func when using tft.
+
+        Args:
+            origin_opt_cls (Class): origin optimizer class.
+        """
+
+        class TFTOptSubCls(origin_opt_cls):
+            """
+            Optimizer wrapper class when using tft.
+            """
+
+            def __init__(self, *args, **kwargs):
+                super(TFTOptSubCls, self).__init__(*args, **kwargs)
+                self.report = TensorReport()
+                self.report_end = TensorReport()
+                self.report_end.add_prim_attr("side_effect_mem", True).add_prim_attr("optimizer_end", True)
+                self.depend = ops.Depend()
+                self.allreduce_sum = ops.AllReduce()
+                self.allreduce_sum.add_prim_attr("tft_report_before", True)
+                self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+            def construct(self, gradients, **kwargs):
+                tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+                self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+                grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+                opt_ret = super(TFTOptSubCls, self).construct(grads, **kwargs)
+                self.report_end("tft_report", self.tft_g_one_flag)
+                return opt_ret
+
+        return TFTOptSubCls
+
+    def _tft_register(self):
+        """Register callback functions."""
         self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
         self.tft.tft_register_rename_handler(_rename_save_result, self)
         self.tft.tft_register_exit_handler(_tft_exit_cb, self)
         self.tft.tft_register_stop_handler(_tft_stop_callback, self)
         self.tft.tft_register_clean_handler(_tft_clean_callback, self)
         self.tft.tft_register_repair_handler(_tft_repair_callback, self)
-
-        world_size = _get_device_num()
-        cur_rank = get_rank()
-        enable_local_copy = False
-        enable_arf = False
-        enable_tls = False
-        tls_key_dir = ""
-
-        if cur_rank == self._controller_rank_id:
-            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
-            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf)
-            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
-            logger.info("Finish start tft controller.")
-
-        logger.info("Begin to start tft processor.")
-        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
-        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
-        logger.info("Finished start tft processor.")
+        self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
 
     def _reset_acc_grads(self):
         accu_grad_params = map(lambda e: e[1],
@@ -360,7 +422,7 @@ class TFTRegister(Callback):
 
     def on_train_step_end(self, run_context):
         """
-
+        Report status to MindIO TFT after every step finished.
 
         Args:
             run_context (RunContext): Context of the train running. Refer to
@@ -371,17 +433,27 @@ class TFTRegister(Callback):
         self._set_tft_optimizer_replica(run_context)
         cb_params = run_context.original_args()
         logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
+        self.cur_step_num = cb_params.cur_step_num
+        self.cur_epoch_num = cb_params.cur_epoch_num
         if cb_params.optimizer is not None:
-            self.global_step =
+            self.global_step = cb_params.optimizer.global_step.clone()
             self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
-
-            self.global_step =
+        elif hasattr(cb_params.network, 'optimizer') and cb_params.network.optimizer is not None:
+            self.global_step = cb_params.network.optimizer.global_step.clone()
             self.assign(cb_params.network.optimizer.tft_g_one_flag, self.g_one)
-
+        else:
+            raise ValueError("TFT feature need optimizer or network's optimizer!")
+        self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
         logger.info("END Set optimizer finish step status to TFT.")
 
-
     def on_train_begin(self, run_context):
+        """
+        Register train params to MindIO TFT on train beginning.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
         cb_params = run_context.original_args()
         sink_size = cb_params.get("sink_size", 0)
         if sink_size > 1:
@@ -391,7 +463,11 @@ class TFTRegister(Callback):
         self.cb_params = cb_params
 
     def end(self, run_context):
-
-
-
-
+        """
+        Unregister MindIO TFT on train end.
+
+        Args:
+            run_context (RunContext): Context of the train running. Refer to
+                :class:`mindspore.train.RunContext` for detail.
+        """
+        _tft_handler.unregister_tft()
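For orientation, the hunks above replace the controller-based `TFTRegister` callback with `TrainFaultTolerance`, whose MindIO session now comes from `mindspore.utils._tft_handler` rather than an in-callback controller/processor startup. A minimal before/after sketch, reconstructed only from the removed and added lines above (the 2.5.0 argument names follow the removed `self._controller_*` assignments; the address/port values and `model`/`dataset` are placeholders as in the docstring example):

from mindspore.train import TrainFaultTolerance

# 2.5.0 style (removed): the callback itself spun up the MindIO controller/processor.
# tft_cb = TFTRegister(ctrl_rank_id=0, ctrl_ip="127.0.0.1", ctrl_port=30051,
#                      ckpt_save_path="./tft_ckpt/")

# 2.6.0rc1 style: MindIO is initialized through _tft_handler (or lazily by the
# callback's _check_init() when MS_ENABLE_TFT is set without ARF), so only the
# save behaviour is configured here.
tft_cb = TrainFaultTolerance(
    ckpt_save_path="./tft_ckpt/",   # default save logic; or pass ckpt_save_fn=...
    initial_step=0,                 # offset added when reporting tft_end_updating_os()
)
# model.train(1, dataset, callbacks=[tft_cb])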
mindspore/train/data_sink.py
CHANGED
@@ -98,6 +98,29 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
     return next_op, (key, dataset_shapes, dataset_types)
 
 
+def _get_jit_func(sink_fun, jit_config):
+    """
+    Get the jit function.
+    """
+    jit_config_dict = jit_config.jit_config_dict
+    jit_level = jit_config_dict['jit_level']
+    if jit_level == "":
+        jit_level = "O0"
+    backend = ""
+    if jit_level == "O2":
+        jit_level = "O0"
+        backend = "GE"
+    if "backend" in jit_config_dict:
+        backend = jit_config_dict["backend"]
+    fullgraph = False
+    if jit_config_dict['jit_syntax_level'] == "STRICT":
+        fullgraph = True
+    exc_mode = jit_config_dict['exc_mode']
+    infer_boost = jit_config_dict['infer_boost']
+    return jit(sink_fun, jit_level=jit_level, backend=backend, fullgraph=fullgraph, exc_mode=exc_mode,
+               infer_boost=infer_boost)
+
+
 def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     """
     get the sink function.
@@ -107,7 +130,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     if jit_config is None:
         dst_sink_fun = sink_fun
     else:
-        dst_sink_fun =
+        dst_sink_fun = _get_jit_func(sink_fun, jit_config)
     dataset.__sink_fun__ = dst_sink_fun
 
     return dataset.__sink_fun__
@@ -119,7 +142,7 @@ def _get_sink_fun(sink_fun, key_info, is_info_queue, dataset, jit_config):
     if jit_config is None:
         dst_sink_fun = sink_fun
     else:
-        dst_sink_fun =
+        dst_sink_fun = _get_jit_func(sink_fun, jit_config)
     dataset.__sink_aux__.sink_funcs[key] = dst_sink_fun
 
     return dst_sink_fun
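To make the new `_get_jit_func` mapping concrete: a `JitConfig` whose `jit_level` is "O2" is lowered to `jit_level="O0"` on the "GE" backend before `jit` wraps the sink function, an empty level defaults to "O0", an explicit `backend` key always wins, and `jit_syntax_level="STRICT"` becomes `fullgraph=True`. A standalone sketch of just that mapping (`_resolve_jit_options` is a hypothetical helper name, not part of the package):

def _resolve_jit_options(jit_config_dict):
    # Mirrors _get_jit_func above: "" -> "O0"; "O2" -> ("O0", backend "GE");
    # an explicit 'backend' key overrides; STRICT syntax -> fullgraph=True.
    jit_level = jit_config_dict.get('jit_level') or "O0"
    backend = ""
    if jit_level == "O2":
        jit_level, backend = "O0", "GE"
    backend = jit_config_dict.get('backend', backend)
    fullgraph = jit_config_dict.get('jit_syntax_level') == "STRICT"
    return jit_level, backend, fullgraph

assert _resolve_jit_options({'jit_level': 'O2', 'jit_syntax_level': 'STRICT'}) == ('O0', 'GE', True)
assert _resolve_jit_options({'jit_level': '', 'jit_syntax_level': 'LAX'}) == ('O0', '', False)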
mindspore/train/dataset_helper.py
CHANGED
@@ -214,8 +214,7 @@ def _get_dataset_aux(dataset):
 
 def connect_network_with_dataset(network, dataset_helper):
     """
-    Connect the `network` with dataset in `dataset_helper`. Only supported in
-    <https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
+    Connect the `network` with dataset in `dataset_helper`. Only supported in sink mode,
     (dataset_sink_mode=True).
 
     Args:
@@ -335,11 +334,11 @@ class DatasetHelper:
         dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
             dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
             Default: ``True``.
-        sink_size (int): Control the amount of data in each sink.
+        sink_size (int): Control the amount of data in each sink. Must be -1 or positive.
             If sink_size=-1, sink the complete dataset for each epoch.
             If sink_size>0, sink sink_size data for each epoch.
-            Default:
-        epoch_num (int): The number of passes of the entire dataset to be sent. Default: 1
+            Default: ``-1``.
+        epoch_num (int): The number of passes of the entire dataset to be sent. Default: ``1``.
 
     Examples:
         >>> import numpy as np
mindspore/train/loss_scale_manager.py
CHANGED
@@ -51,9 +51,10 @@ class FixedLossScaleManager(LossScaleManager):
     inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        loss_scale (float): Magnification factor of gradients.
+        loss_scale (float, optional): Magnification factor of gradients.
+            Note that if `drop_overflow_update` is set to ``False`` ,
             the value of `loss_scale` in optimizer should be set to the same as here. Default: ``128.0`` .
-        drop_overflow_update (bool): Whether to execute optimizer if there is an overflow.
+        drop_overflow_update (bool, optional): Whether to execute optimizer if there is an overflow.
             If ``True`` , the optimizer will
             not executed when overflow occurs. Default: ``True`` .
 
@@ -110,8 +111,8 @@ class FixedLossScaleManager(LossScaleManager):
 
         Returns:
             None or :class:`mindspore.nn.FixedLossScaleUpdateCell`. Instance of
-            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is True
-            `drop_overflow_update` is False
+            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is ``True``. None when
+            `drop_overflow_update` is ``False``.
         """
         if not self._drop_overflow_update:
             return None
@@ -124,9 +125,9 @@ class DynamicLossScaleManager(LossScaleManager):
     adjusted, inherits from :class:`mindspore.amp.LossScaleManager`.
 
     Args:
-        init_loss_scale (float): Initialize loss scale. Default: ``2 ** 24`` .
-        scale_factor (int): Coefficient of increase and decrease. Default: ``2`` .
-        scale_window (int): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
+        init_loss_scale (float, optional): Initialize loss scale. Default: ``2 ** 24`` .
+        scale_factor (int, optional): Coefficient of increase and decrease. Default: ``2`` .
+        scale_window (int, optional): Maximum continuous normal steps when there is no overflow. Default: ``2000`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
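As a usage note on the clarified `FixedLossScaleManager` docstring: with `drop_overflow_update=False` the manager's `get_update_cell()` returns None and the optimizer always runs, so the optimizer must divide gradients by the same fixed factor the manager applies to the loss. A small sketch assuming a stand-in network and loss (not from the diff):

import mindspore.nn as nn
from mindspore.amp import FixedLossScaleManager
from mindspore.train import Model

net = nn.Dense(16, 10)                                       # stand-in network
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
loss_scale = 1024.0
# Keep both values identical, as the docstring requires.
manager = FixedLossScaleManager(loss_scale, drop_overflow_update=False)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9,
                  loss_scale=loss_scale)
model = Model(net, loss_fn=loss_fn, optimizer=opt, loss_scale_manager=manager)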
mindspore/train/metrics/accuracy.py
CHANGED
@@ -45,11 +45,11 @@ class Accuracy(EvaluationBase):
     >>> from mindspore import Tensor
     >>> from mindspore.train import Accuracy
     >>>
-    >>>
-    >>>
+    >>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
+    >>> y_true = Tensor(np.array([1, 0, 1]), mindspore.float32)
     >>> metric = Accuracy('classification')
     >>> metric.clear()
-    >>> metric.update(
+    >>> metric.update(y_pred, y_true)
     >>> accuracy = metric.eval()
     >>> print(accuracy)
     0.6666666666666666
mindspore/train/metrics/confusion_matrix.py
CHANGED
@@ -23,15 +23,15 @@ from mindspore.train.metrics.metric import Metric, rearrange_inputs
 
 class ConfusionMatrix(Metric):
     """
-    Computes the
+    Computes the Confusion Matrix, which is commonly used to evaluate the performance of classification models,
     including binary classification and multiple classification.
 
-    If you only need
+    If you only need Confusion Matrix, use this class. If you want to calculate other metrics, such as 'PPV',
     'TPR', 'TNR', etc., use class :class:`mindspore.train.ConfusionMatrixMetric` .
 
     Args:
         num_classes (int): Number of classes in the dataset.
-        normalize (str): Normalization mode for
+        normalize (str): Normalization mode for Confusion Matrix. Default: ``"no_norm"`` . Choose from:
 
         - ``"no_norm"`` : No Normalization is used. Default: ``None``.
         - ``"target"`` : Normalization based on target value.
@@ -78,7 +78,7 @@ class ConfusionMatrix(Metric):
     @rearrange_inputs
     def update(self, *inputs):
         """
-        Update state with y_pred and y
+        Update state with `y_pred` and `y`.
 
         Args:
             inputs(tuple): Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, list or numpy.ndarray.
@@ -88,7 +88,7 @@ class ConfusionMatrix(Metric):
 
         Raises:
             ValueError: If the number of inputs is not 2.
-            ValueError: If the
+            ValueError: If the dims of `y_pred` and `y` are not equal.
         """
         if len(inputs) != 2:
             raise ValueError("For 'ConfusionMatrix.update', it needs 2 inputs (predicted value, true value), "
@@ -151,8 +151,8 @@ class ConfusionMatrixMetric(Metric):
     batch, class channel and iteration are collected. All metrics supported by the interface are listed in comments
     of `metric_name`.
 
-    If you want to calculate metrics related to confusion matrix, such as 'PPV', 'TPR', 'TNR', use this class.
-    If you only want to calculate confusion matrix, please use :class:`mindspore.train.ConfusionMatrix` .
+    - If you want to calculate metrics related to confusion matrix, such as 'PPV', 'TPR', 'TNR', use this class.
+    - If you only want to calculate confusion matrix, please use :class:`mindspore.train.ConfusionMatrix` .
 
     Args:
         skip_channel (bool): Whether to skip the measurement calculation on the first channel of the predicted output.
@@ -163,9 +163,9 @@ class ConfusionMatrixMetric(Metric):
             "threat score", "accuracy", "balanced accuracy", "f1 score",
             "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"].
             Default: ``"sensitivity"`` .
-        calculation_method (bool): If
+        calculation_method (bool): If ``True``, the measurement for each sample will be calculated first.
             If not, the confusion matrix of all samples will be accumulated first.
-            As for classification task, 'calculation_method' should be False
+            As for classification task, 'calculation_method' should be ``False``. Default: ``False`` .
         decrease (str): The reduction method on data batch. `decrease` takes effect only when calculation_method
             is True. Default: ``"mean"`` . Choose from:
             ["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"].
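To illustrate the distinction the reworded docstrings draw (`ConfusionMatrix` returns the raw matrix; `ConfusionMatrixMetric` computes derived scores such as 'PPV'/'TPR'), a short sketch along the lines of the class's documented usage; the sample tensors are illustrative:

import numpy as np
from mindspore import Tensor
from mindspore.train import ConfusionMatrix

metric = ConfusionMatrix(num_classes=2, normalize="no_norm", threshold=0.5)
metric.clear()
y_pred = Tensor(np.array([1, 0, 1, 0]))
y = Tensor(np.array([1, 0, 0, 1]))
metric.update(y_pred, y)
print(metric.eval())
# one match and one mismatch per class ->
# [[1. 1.]
#  [1. 1.]]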