mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic. See the release's details page for more information.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py
CHANGED
|
@@ -14,6 +14,11 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Auto mixed precision."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
|
+
import inspect
|
|
18
|
+
import types
|
|
19
|
+
from typing import Any
|
|
20
|
+
import functools
|
|
21
|
+
import collections
|
|
17
22
|
|
|
18
23
|
import mindspore as ms
|
|
19
24
|
from mindspore import nn
|
|
@@ -27,8 +32,9 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScal
|
|
|
27
32
|
from mindspore import boost, context
|
|
28
33
|
from mindspore.ops import operations as P
|
|
29
34
|
from mindspore.ops import Primitive
|
|
35
|
+
from mindspore.ops import auto_generate as gen
|
|
30
36
|
from mindspore import log as logger
|
|
31
|
-
|
|
37
|
+
from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel
|
|
32
38
|
|
|
33
39
|
AMP_WHITE_LIST = [
|
|
34
40
|
nn.Conv1d,
|
|
@@ -50,19 +56,81 @@ AMP_WHITE_LIST = [
|
|
|
50
56
|
P.BatchMatMul,
|
|
51
57
|
P.PReLU,
|
|
52
58
|
P.ReLU,
|
|
53
|
-
P.Ger
|
|
59
|
+
P.Ger,
|
|
54
60
|
]
|
|
55
61
|
|
|
56
|
-
|
|
57
62
|
AMP_BLACK_LIST = [
|
|
58
63
|
nn.BatchNorm1d,
|
|
59
64
|
nn.BatchNorm2d,
|
|
60
65
|
nn.BatchNorm3d,
|
|
61
|
-
nn.LayerNorm
|
|
66
|
+
nn.LayerNorm,
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
AMP_AUTO_WHITE_LIST = [
|
|
70
|
+
P.Conv2D,
|
|
71
|
+
P.Conv3D,
|
|
72
|
+
P.Conv2DTranspose,
|
|
73
|
+
P.Conv3DTranspose,
|
|
74
|
+
gen.Convolution,
|
|
75
|
+
P.MatMul,
|
|
76
|
+
gen.MatMulExt,
|
|
77
|
+
P.BatchMatMul,
|
|
78
|
+
gen.BatchMatMulExt,
|
|
79
|
+
gen.PReLU,
|
|
80
|
+
P.Einsum,
|
|
81
|
+
gen.Dense,
|
|
82
|
+
gen.Addmm,
|
|
62
83
|
]
|
|
63
84
|
|
|
85
|
+
AMP_AUTO_BLACK_LIST = [
|
|
86
|
+
gen.Pow,
|
|
87
|
+
gen.ACos,
|
|
88
|
+
gen.Asin,
|
|
89
|
+
gen.Cosh,
|
|
90
|
+
P.Erfinv,
|
|
91
|
+
P.Exp,
|
|
92
|
+
P.Expm1,
|
|
93
|
+
P.Log,
|
|
94
|
+
P.Log1p,
|
|
95
|
+
P.Reciprocal,
|
|
96
|
+
P.Rsqrt,
|
|
97
|
+
P.Sinh,
|
|
98
|
+
P.Tan,
|
|
99
|
+
P.Softplus,
|
|
100
|
+
gen.SoftplusExt,
|
|
101
|
+
P.LayerNorm,
|
|
102
|
+
gen.LayerNormExt,
|
|
103
|
+
P.BatchNorm,
|
|
104
|
+
gen.GroupNorm,
|
|
105
|
+
P.KLDivLoss,
|
|
106
|
+
P.SmoothL1Loss,
|
|
107
|
+
P.MultilabelMarginLoss,
|
|
108
|
+
P.SoftMarginLoss,
|
|
109
|
+
P.TripletMarginLoss,
|
|
110
|
+
P.MultiMarginLoss,
|
|
111
|
+
P.BCEWithLogitsLoss,
|
|
112
|
+
P.Pdist,
|
|
113
|
+
P.Cdist,
|
|
114
|
+
P.Renorm,
|
|
115
|
+
]
|
|
116
|
+
|
|
117
|
+
# Indicates which inputs of primitives need to be converted
|
|
118
|
+
AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})
|
|
119
|
+
|
|
120
|
+
# Primitives in inner amp black list will not be converted in O2/O3
|
|
121
|
+
_INNER_AMP_BLACK_LIST = []
|
|
122
|
+
|
|
64
123
|
MS_AMP_BY_REWRITE = False
|
|
65
|
-
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def amp_cast(value, dtype):
|
|
127
|
+
"""This function is used to insert cast operators for tensors during auto mixed precision."""
|
|
128
|
+
if isinstance(value, ms.Tensor) and value.dtype in mstype.float_type:
|
|
129
|
+
return P.Cast()(value, dtype)
|
|
130
|
+
return value
|
|
131
|
+
|
|
132
|
+
_amp_cast_op = amp_cast
|
|
133
|
+
|
|
66
134
|
|
|
67
135
|
class _OutputTo16(nn.Cell):
|
|
68
136
|
"""Wrap cell for amp. Cast network output back to float16."""
|
|
@@ -88,278 +156,185 @@ class _OutputTo32(nn.Cell):
|
|
|
88
156
|
return F.mixed_precision_cast(mstype.float32, out)
|
|
89
157
|
|
|
90
158
|
|
|
91
|
-
|
|
92
|
-
def _allow_mix_precision(node, allowed_list, dtype) -> bool:
|
|
159
|
+
def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
|
|
93
160
|
"""
|
|
94
|
-
Check whether current node need
|
|
95
|
-
1) Type of node is
|
|
96
|
-
2)
|
|
97
|
-
3)
|
|
161
|
+
Check whether current node is a operator that need to be casted. Follow conditions need to be satisfied:
|
|
162
|
+
1) Type of node is CallPrimitive and type of instance is Primitive
|
|
163
|
+
2) Type of instance is not P.Cast
|
|
164
|
+
3) force_cast is True, which means one of upper layer cells is under casting
|
|
165
|
+
4) white_list exist and type of node is in white_list
|
|
166
|
+
5) black_list exist and type of node is in not black_list
|
|
98
167
|
"""
|
|
99
|
-
|
|
100
|
-
if node_inst in allowed_list:
|
|
101
|
-
return True
|
|
102
|
-
if node.get_targets() is None:
|
|
168
|
+
if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
|
|
103
169
|
return False
|
|
104
|
-
if not
|
|
170
|
+
if not inspect.isclass(node.get_instance_type()):
|
|
105
171
|
return False
|
|
106
|
-
if
|
|
172
|
+
if not issubclass(node.get_instance_type(), Primitive):
|
|
107
173
|
return False
|
|
108
|
-
if issubclass(node.get_instance_type(),
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
174
|
+
if issubclass(node.get_instance_type(), P.Cast):
|
|
175
|
+
return False
|
|
176
|
+
if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
|
|
177
|
+
return False
|
|
178
|
+
if force_cast:
|
|
179
|
+
return True
|
|
180
|
+
if white_list is not None and node.get_instance_type() in white_list:
|
|
181
|
+
return True
|
|
182
|
+
if black_list is not None and node.get_instance_type() not in black_list:
|
|
183
|
+
return True
|
|
184
|
+
return False
|
|
117
185
|
|
|
118
186
|
|
|
119
|
-
def
|
|
120
|
-
"""
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
if issubclass(node.get_instance_type(), Primitive):
|
|
126
|
-
for idx, arg in enumerate(node.get_args()):
|
|
127
|
-
position = stree.before(node)
|
|
128
|
-
new_node = _amp_cast_op()
|
|
129
|
-
cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
|
|
130
|
-
arg_provider = node.get_handler().get_arg_providers()[idx]
|
|
131
|
-
if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
|
|
132
|
-
cast_targets = [stree.unique_name(str(arg))]
|
|
133
|
-
else:
|
|
134
|
-
cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
|
|
135
|
-
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
136
|
-
targets=cast_targets,
|
|
137
|
-
args=cast_args,
|
|
138
|
-
name='incast_{}{}'.format(node.get_name(), idx))
|
|
139
|
-
stree.insert(position, new_cast_node)
|
|
140
|
-
node.set_arg_by_node(idx, new_cast_node)
|
|
141
|
-
# insert cast fp16/bf16 before the Cell operators
|
|
142
|
-
elif issubclass(node.get_instance_type(), nn.Cell):
|
|
143
|
-
node.get_instance().to_float(dtype)
|
|
144
|
-
# ignore if subclass is not one of (Primitive, nn.Cell)
|
|
145
|
-
else:
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
# insert cast float32 after the operators
|
|
149
|
-
position = stree.after(node)
|
|
150
|
-
new_node = _amp_cast_op()
|
|
151
|
-
cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
152
|
-
"mindspore.float32"])
|
|
153
|
-
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
154
|
-
targets=[node.get_targets()[0]],
|
|
155
|
-
args=cast_args,
|
|
156
|
-
name='outcast_{}'.format(node.get_name()))
|
|
157
|
-
# insert node & unique names
|
|
158
|
-
stree.insert(position, new_cast_node)
|
|
159
|
-
# update argument names
|
|
160
|
-
for user in node.get_users():
|
|
161
|
-
if user.get_name() == new_cast_node.get_name():
|
|
162
|
-
continue
|
|
163
|
-
for idx, arg in enumerate(user.get_args()):
|
|
164
|
-
if arg == node.get_targets()[0]:
|
|
165
|
-
user.set_arg_by_node(idx, new_cast_node)
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def _insert_cast_operator_white_list(stree, white_list, dtype):
|
|
169
|
-
"""insert cast for operators in white_list."""
|
|
170
|
-
allowed_list = []
|
|
171
|
-
# Ignore if net called ".to_float(dtype)"
|
|
172
|
-
net = stree.get_handler().get_origin_network()
|
|
173
|
-
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
174
|
-
if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
|
|
175
|
-
return
|
|
176
|
-
node_list = []
|
|
177
|
-
node_list.extend(list(stree.nodes()))
|
|
178
|
-
while node_list:
|
|
179
|
-
node = node_list.pop()
|
|
180
|
-
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
181
|
-
if MS_AMP_BY_REWRITE:
|
|
182
|
-
_insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
|
|
183
|
-
for n in node.get_handler().node_list:
|
|
184
|
-
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
185
|
-
_insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
|
|
186
|
-
white_list, dtype)
|
|
187
|
-
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
188
|
-
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
189
|
-
_insert_cast_operator_white_list(substree, white_list, dtype)
|
|
190
|
-
elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
|
|
191
|
-
if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
|
|
192
|
-
nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
|
|
193
|
-
node_list.extend(nodes)
|
|
194
|
-
elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
|
|
195
|
-
_insert_cast_operator_process(node, dtype)
|
|
187
|
+
def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
|
|
188
|
+
"""Check whether cell precision is set by user."""
|
|
189
|
+
for flag in ["fp32", "fp16", "bf16"]:
|
|
190
|
+
if hasattr(cell_inst, flag) and getattr(cell_inst, flag):
|
|
191
|
+
return True
|
|
192
|
+
return False
|
|
196
193
|
|
|
197
194
|
|
|
198
|
-
def
|
|
195
|
+
def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
|
|
199
196
|
"""
|
|
200
|
-
|
|
201
|
-
|
|
197
|
+
Check whether current node is type of tree whose network needs to be casted. Follow conditions need to
|
|
198
|
+
be satisfied:
|
|
199
|
+
1) Type of node is Tree and type of instance is Cell
|
|
200
|
+
2) Cell.to_float(xxx) is not set by user
|
|
201
|
+
3) force_cast is True, which means one of upper layer networks is under casting
|
|
202
|
+
4) white_list exist and type of node is in white_list
|
|
203
|
+
5) black_list exist and type of node is in not black_list
|
|
202
204
|
"""
|
|
205
|
+
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
|
206
|
+
return False
|
|
207
|
+
if not inspect.isclass(node.get_instance_type()):
|
|
208
|
+
return False
|
|
209
|
+
if not issubclass(node.get_instance_type(), nn.Cell):
|
|
210
|
+
return False
|
|
211
|
+
if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
|
|
212
|
+
return False
|
|
213
|
+
if _precision_set_by_user(node.get_instance()):
|
|
214
|
+
return False
|
|
215
|
+
if force_cast:
|
|
216
|
+
return True
|
|
217
|
+
if white_list is not None and node.get_instance_type() in white_list:
|
|
218
|
+
return True
|
|
219
|
+
if black_list is not None and node.get_instance_type() not in black_list:
|
|
220
|
+
return True
|
|
221
|
+
return False
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _insert_cast_for_operator(node, dtype):
|
|
225
|
+
"""insert cast pair for node."""
|
|
226
|
+
dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
|
|
227
|
+
stree = node.get_symbol_tree()
|
|
228
|
+
# insert cast fp16/bf16 for inputs of node
|
|
229
|
+
for idx, arg in enumerate(node.get_args()):
|
|
230
|
+
if arg.type != ms.rewrite.ValueType.NamingValue:
|
|
231
|
+
continue
|
|
232
|
+
incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, "mindspore"])
|
|
233
|
+
arg_providers = node.get_arg_providers()
|
|
234
|
+
if not arg_providers or idx not in arg_providers or \
|
|
235
|
+
len(arg_providers[idx][0].get_target_users(arg_providers[idx][1])) > 1:
|
|
236
|
+
# create new target names when argument is used by other node
|
|
237
|
+
incast_targets = [stree.unique_name(f"{arg.value}_var")]
|
|
238
|
+
else:
|
|
239
|
+
incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
|
|
240
|
+
incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
|
|
241
|
+
stree.insert(stree.before(node), incast_node)
|
|
242
|
+
node.set_arg_by_node(idx, incast_node)
|
|
243
|
+
# insert cast fp32 for outputs of node
|
|
244
|
+
for _, target in enumerate(node.get_targets()):
|
|
245
|
+
if target.type != ms.rewrite.ValueType.NamingValue:
|
|
246
|
+
continue
|
|
247
|
+
outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
|
|
248
|
+
[target.scope, "mindspore"])
|
|
249
|
+
outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
|
|
250
|
+
outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
|
|
251
|
+
stree.insert(stree.after(node), outcast_node)
|
|
203
252
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
if (white_list is not None and current_node.get_instance_type() in white_list) or \
|
|
220
|
-
(black_list is not None and current_node.get_instance_type() not in black_list) and \
|
|
221
|
-
(_allow_mix_precision(current_node, allowed_list, dtype)):
|
|
222
|
-
cast_flag = True
|
|
223
|
-
current_node.get_instance().to_float(dtype)
|
|
224
|
-
elif cast_flag:
|
|
225
|
-
# cast next node back to float32
|
|
226
|
-
current_node.get_instance().to_float(mstype.float32)
|
|
227
|
-
cast_flag = False
|
|
228
|
-
if cast_flag and current_node:
|
|
229
|
-
# if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
|
|
230
|
-
cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
|
|
231
|
-
args=[current_node.get_targets()[0]],
|
|
232
|
-
targets=[current_node.get_targets()[0]],
|
|
233
|
-
name=f"outcast_{cell_container.get_name()}")
|
|
234
|
-
stree.insert(stree.after(current_node), cast_node)
|
|
253
|
+
|
|
254
|
+
def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
|
|
255
|
+
"""insert cast for operators not in black_list."""
|
|
256
|
+
# get all nodes of stree exclude nodes in subtree.
|
|
257
|
+
all_nodes = stree.all_nodes(False)
|
|
258
|
+
for node in all_nodes:
|
|
259
|
+
if not node.get_targets():
|
|
260
|
+
continue
|
|
261
|
+
if _operator_need_cast(node, force_cast, white_list, black_list):
|
|
262
|
+
_insert_cast_for_operator(node, dtype)
|
|
263
|
+
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
264
|
+
force_cast_ = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
|
|
265
|
+
if not _precision_set_by_user(node.get_instance()):
|
|
266
|
+
subtree = node.get_sub_tree()
|
|
267
|
+
_insert_cast_for_operators(subtree, dtype, force_cast_, white_list=white_list, black_list=black_list)
|
|
235
268
|
|
|
236
269
|
|
|
237
270
|
def _need_removed_cast_pair(node, dtype):
|
|
238
271
|
"""check whether the cast pairs should be removed."""
|
|
239
|
-
dtype_str = "
|
|
240
|
-
cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "
|
|
272
|
+
dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
|
|
273
|
+
cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "float32"], ["mindspore", "mindspore"])
|
|
241
274
|
cast_dtype_f16 = cast_dtypes[0]
|
|
242
275
|
cast_dtype_f32 = cast_dtypes[1]
|
|
243
|
-
# current node should be
|
|
276
|
+
# current node should be cast fp32
|
|
244
277
|
if node.get_instance_type() != _amp_cast_op:
|
|
245
278
|
return False
|
|
246
279
|
node_cast_type = node.get_args()[1]
|
|
247
280
|
if node_cast_type != cast_dtype_f32:
|
|
248
281
|
return False
|
|
249
|
-
# all user nodes should be
|
|
282
|
+
# all user nodes should be cast fp16/bf16
|
|
250
283
|
if not node.get_users():
|
|
251
284
|
return False
|
|
252
285
|
all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
|
|
253
286
|
for user in node.get_users():
|
|
254
|
-
# If ControlFlow node(if
|
|
287
|
+
# If ControlFlow node(e.g. if, for, while) exists between current node and user node,
|
|
255
288
|
# cast pair should not be removed.
|
|
256
289
|
middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
|
|
257
290
|
if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
|
|
258
291
|
return False
|
|
259
|
-
if
|
|
260
|
-
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
261
|
-
if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
|
|
262
|
-
return False
|
|
263
|
-
elif user.get_instance_type() == _amp_cast_op:
|
|
264
|
-
user_cast_type = user.get_args()[1]
|
|
265
|
-
if user_cast_type != cast_dtype_f16:
|
|
266
|
-
return False
|
|
267
|
-
else:
|
|
292
|
+
if user.get_instance_type() != _amp_cast_op:
|
|
268
293
|
return False
|
|
294
|
+
user_cast_type = user.get_args()[1]
|
|
295
|
+
if user_cast_type != cast_dtype_f16:
|
|
296
|
+
return False
|
|
297
|
+
# cast pair detected, check next user
|
|
298
|
+
continue
|
|
269
299
|
return True
|
|
270
300
|
|
|
271
301
|
|
|
272
|
-
def _removed_cast_pair_process(cast_f32_node):
|
|
273
|
-
"""remove the duplicated cast operators."""
|
|
274
|
-
stree = cast_f32_node.get_symbol_tree()
|
|
275
|
-
cast_f32_users = cast_f32_node.get_users()
|
|
276
|
-
# remove cast f16 nodes
|
|
277
|
-
for user_node in cast_f32_users:
|
|
278
|
-
if user_node.get_instance_type() == _amp_cast_op:
|
|
279
|
-
cast_f16_node = user_node
|
|
280
|
-
# modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
|
|
281
|
-
for cast_f16_user in cast_f16_node.get_users():
|
|
282
|
-
for idx, arg in enumerate(cast_f16_user.get_args()):
|
|
283
|
-
if arg == cast_f16_node.get_targets()[0]:
|
|
284
|
-
cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
|
|
285
|
-
stree.erase(cast_f16_node)
|
|
286
|
-
# update args of cell f16 nodes
|
|
287
|
-
elif isinstance(user_node.get_instance(), nn.Cell):
|
|
288
|
-
cell_f16_node = user_node
|
|
289
|
-
for idx, arg in enumerate(cell_f16_node.get_args()):
|
|
290
|
-
if arg == cast_f32_node.get_targets()[0]:
|
|
291
|
-
cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
|
|
292
|
-
# remove the cast f32 node
|
|
293
|
-
stree.erase(cast_f32_node)
|
|
294
|
-
|
|
295
|
-
|
|
296
302
|
def _remove_duplicated_cast(stree, dtype):
|
|
297
303
|
"""remove the duplicated cast operators."""
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
while node_list:
|
|
301
|
-
node = node_list.pop()
|
|
302
|
-
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
303
|
-
for n in node.get_handler().node_list:
|
|
304
|
-
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
305
|
-
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
|
|
306
|
-
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
307
|
-
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
308
|
-
_remove_duplicated_cast(substree, dtype)
|
|
309
|
-
elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
|
|
310
|
-
if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
|
|
311
|
-
nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
|
|
312
|
-
node_list.extend(nodes)
|
|
313
|
-
elif _need_removed_cast_pair(node, dtype):
|
|
314
|
-
_removed_cast_pair_process(node)
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
def _auto_white_list(network, white_list, dtype):
|
|
318
|
-
"""process the white list of network."""
|
|
319
|
-
stree = ms.rewrite.SymbolTree.create(network)
|
|
320
|
-
_insert_cast_operator_white_list(stree, white_list, dtype)
|
|
321
|
-
_remove_duplicated_cast(stree, dtype)
|
|
322
|
-
return stree.get_network()
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
def _insert_cast_operator_black_list(stree, black_list, dtype):
|
|
326
|
-
"""insert cast for operators not in black_list."""
|
|
327
|
-
allowed_list = []
|
|
328
|
-
# Ignore if net called ".to_float(dtype)"
|
|
329
|
-
net = stree.get_handler().get_origin_network()
|
|
330
|
-
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
331
|
-
if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
|
|
332
|
-
return
|
|
333
|
-
for node in stree.nodes(all_nodes=True):
|
|
334
|
-
if node.get_targets() is None:
|
|
335
|
-
continue
|
|
336
|
-
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
337
|
-
_insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
|
|
338
|
-
elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
|
|
339
|
-
# nodes in CellContainer are processed by _insert_cast_for_cell_container
|
|
340
|
-
continue
|
|
341
|
-
elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
|
|
342
|
-
_insert_cast_operator_process(node, dtype)
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
def _remove_duplicated_cast_rewrite(stree, dtype):
|
|
346
|
-
"""remove the duplicated cast operators."""
|
|
347
|
-
for node in stree.nodes(all_nodes=True):
|
|
304
|
+
all_nodes = list(stree.nodes(all_nodes=True))
|
|
305
|
+
for node in all_nodes:
|
|
348
306
|
if _need_removed_cast_pair(node, dtype):
|
|
349
|
-
|
|
350
|
-
# remove cast
|
|
351
|
-
for
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
307
|
+
incast_nodes = node.get_users()
|
|
308
|
+
# remove cast fp16/bf16 nodes
|
|
309
|
+
for incast_node in incast_nodes:
|
|
310
|
+
# get_target_users() return {target0: [(user0, arg_idx), ...], ...}
|
|
311
|
+
target_users = list(incast_node.get_target_users().values())
|
|
312
|
+
if not target_users or not target_users[0]:
|
|
313
|
+
continue
|
|
314
|
+
for user_node, arg_idx in target_users[0]:
|
|
315
|
+
user_node.set_arg(arg_idx, incast_node.get_args()[0])
|
|
316
|
+
stree.erase(incast_node)
|
|
317
|
+
# remove the cast fp32 node
|
|
355
318
|
stree.erase(node)
|
|
356
319
|
|
|
357
320
|
|
|
358
|
-
def
|
|
321
|
+
def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
|
|
322
|
+
"""Implement auto mixed precision by rewrite"""
|
|
323
|
+
if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
|
|
324
|
+
raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
|
|
325
|
+
# enable rewrite configs for amp
|
|
326
|
+
ms.rewrite.common.namespace._ms_cells_to_subtree = True
|
|
327
|
+
ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
|
|
328
|
+
# insert casts by rewrite
|
|
359
329
|
stree = ms.rewrite.SymbolTree.create(network)
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
330
|
+
_insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
|
|
331
|
+
_remove_duplicated_cast(stree, dtype)
|
|
332
|
+
new_net = stree.get_network()
|
|
333
|
+
# disable rewrite configs
|
|
334
|
+
ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
|
|
335
|
+
ms.rewrite.common.namespace._ms_cells_to_subtree = False
|
|
336
|
+
ms.rewrite.common.config.clear_caches()
|
|
337
|
+
return new_net
|
|
363
338
|
|
|
364
339
|
|
|
365
340
|
def _auto_black_list(network, black_list, dtype):
|
|
@@ -381,6 +356,42 @@ def _auto_black_list(network, black_list, dtype):
|
|
|
381
356
|
return network
|
|
382
357
|
|
|
383
358
|
|
|
359
|
+
class amp_decorator:
|
|
360
|
+
"""
|
|
361
|
+
Auto mixed precision decorator.
|
|
362
|
+
Type of lists: List[Tuple[str, List[int]]]
|
|
363
|
+
"""
|
|
364
|
+
def __init__(self, amp_level, amp_dtype, white_list, black_list):
|
|
365
|
+
self.amp_level = amp_level
|
|
366
|
+
self.amp_dtype = amp_dtype
|
|
367
|
+
self.white_list = white_list
|
|
368
|
+
self.black_list = black_list
|
|
369
|
+
|
|
370
|
+
def __enter__(self):
|
|
371
|
+
push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)
|
|
372
|
+
|
|
373
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
|
|
374
|
+
pop_amp_strategy()
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
|
|
378
|
+
"""
|
|
379
|
+
Set auto mixed precision context decorator for object.
|
|
380
|
+
Type of lists: List[Tuple[str, List[int]]]
|
|
381
|
+
"""
|
|
382
|
+
if inspect.isfunction(obj) or inspect.ismethod(obj):
|
|
383
|
+
@functools.wraps(obj)
|
|
384
|
+
def wrapper(*args, **kwargs):
|
|
385
|
+
with amp_decorator(amp_level, amp_dtype, white_list, black_list):
|
|
386
|
+
return obj(*args, **kwargs)
|
|
387
|
+
return wrapper
|
|
388
|
+
if isinstance(obj, nn.Cell):
|
|
389
|
+
obj.construct = types.MethodType(
|
|
390
|
+
_set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
|
|
391
|
+
return obj
|
|
392
|
+
raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")
|
|
393
|
+
|
|
394
|
+
|
|
384
395
|
def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
385
396
|
"""
|
|
386
397
|
Returns a network processed with auto mixed precision.
|
|
@@ -391,26 +402,44 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
391
402
|
converted to lower precision float, and calculation results are converted back to full precision float,
|
|
392
403
|
i.e. ``mstype.float32`` .
|
|
393
404
|
|
|
394
|
-
The
|
|
395
|
-
operators are specifically converted.
|
|
405
|
+
The `amp_level` and its corresponding lists determine which cells and operators are converted.
|
|
396
406
|
|
|
397
|
-
|
|
407
|
+
When `amp_level` is set to ``O0``, no cells and operators are converted.
|
|
398
408
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
:class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
|
|
402
|
-
:class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
|
|
403
|
-
:class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
|
|
404
|
-
:class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
|
|
405
|
-
:class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
|
|
409
|
+
When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
|
|
410
|
+
operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.
|
|
406
411
|
|
|
407
|
-
|
|
412
|
+
When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
|
|
413
|
+
list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.
|
|
408
414
|
|
|
409
|
-
|
|
410
|
-
|
|
415
|
+
When `amp_level` is set to ``O3``, all cells will be converted to low precision.
|
|
416
|
+
|
|
417
|
+
When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
|
|
418
|
+
operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
|
|
419
|
+
`promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
|
|
420
|
+
not listed will run in the type defined by their inputs.
|
|
421
|
+
|
|
422
|
+
Operators in `auto_whitelist` are:
|
|
423
|
+
|
|
424
|
+
``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
|
|
425
|
+
``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
|
|
426
|
+
|
|
427
|
+
Operators in `auto_blacklist` are:
|
|
428
|
+
|
|
429
|
+
``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
|
|
430
|
+
``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
|
|
431
|
+
``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
|
|
432
|
+
``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
|
|
433
|
+
``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
|
|
434
|
+
``Norm``
|
|
435
|
+
|
|
436
|
+
Operators in `promote_list` are:
|
|
437
|
+
|
|
438
|
+
``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
|
|
439
|
+
``BiasAdd``
|
|
411
440
|
|
|
412
441
|
For details on automatic mixed precision, refer to
|
|
413
|
-
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/
|
|
442
|
+
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
|
|
414
443
|
|
|
415
444
|
Note:
|
|
416
445
|
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
|
|
@@ -418,10 +447,18 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
418
447
|
- If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
|
|
419
448
|
mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
|
|
420
449
|
need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
|
|
450
|
+
- When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
|
|
451
|
+
may need to manually convert the type to avoid type inconsistency errors of the loss function.
|
|
452
|
+
- When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
|
|
453
|
+
specified by `to_float` takes effect first.
|
|
454
|
+
|
|
455
|
+
.. warning::
|
|
456
|
+
``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.
|
|
421
457
|
|
|
422
458
|
Args:
|
|
423
|
-
network (Cell): Definition of the network.
|
|
424
|
-
|
|
459
|
+
network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
|
|
460
|
+
is set to ``auto`` .
|
|
461
|
+
amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
|
|
425
462
|
|
|
426
463
|
- "O0": Do not change.
|
|
427
464
|
- "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
|
|
@@ -429,25 +466,34 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
429
466
|
- "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
|
|
430
467
|
to lower precision operations.
|
|
431
468
|
- "O3": Cast network to lower precision.
|
|
469
|
+
- "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
|
|
470
|
+
`auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
|
|
471
|
+
to the higher accuracy float type of the operator inputs, and operators not listed will run in the
|
|
472
|
+
type defined by their inputs.
|
|
432
473
|
|
|
433
474
|
dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
|
|
434
475
|
default: ``mstype.float16`` .
|
|
435
476
|
|
|
436
477
|
Raises:
|
|
437
|
-
TypeError: If `network` is not a Cell.
|
|
478
|
+
TypeError: If `network` is not a Cell or a function.
|
|
438
479
|
ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
|
|
439
480
|
ValueError: If `amp_level` is not within the supported range.
|
|
440
481
|
|
|
441
482
|
Examples:
|
|
442
483
|
>>> from mindspore import amp
|
|
443
484
|
>>> # Define the network structure of LeNet5. Refer to
|
|
444
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
485
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
445
486
|
>>> network = LeNet5()
|
|
446
487
|
>>> amp_level = "O1"
|
|
447
488
|
>>> net = amp.auto_mixed_precision(network, amp_level)
|
|
448
489
|
"""
|
|
449
490
|
if not isinstance(network, nn.Cell):
|
|
450
|
-
|
|
491
|
+
if amp_level == "auto":
|
|
492
|
+
if not inspect.isfunction(network) and not inspect.ismethod(network):
|
|
493
|
+
raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
|
|
494
|
+
# function is supported for amp_level 'auto'
|
|
495
|
+
else:
|
|
496
|
+
raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")
|
|
451
497
|
|
|
452
498
|
if dtype not in (mstype.float16, mstype.bfloat16):
|
|
453
499
|
raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
|
|
@@ -456,27 +502,35 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
456
502
|
return network
|
|
457
503
|
|
|
458
504
|
# Return network if the same amp level has already been configurated
|
|
459
|
-
if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
|
|
505
|
+
if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
|
|
460
506
|
logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
|
|
461
507
|
f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
|
|
462
508
|
f"degradation.")
|
|
463
509
|
|
|
464
510
|
if amp_level == "O1":
|
|
465
|
-
network =
|
|
511
|
+
network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
|
|
466
512
|
elif amp_level == "O2":
|
|
467
513
|
if MS_AMP_BY_REWRITE:
|
|
468
|
-
network =
|
|
514
|
+
network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
|
|
469
515
|
else:
|
|
470
516
|
network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
|
|
471
517
|
network = _OutputTo32(network)
|
|
472
518
|
elif amp_level == "O3":
|
|
473
519
|
if MS_AMP_BY_REWRITE:
|
|
474
|
-
network =
|
|
520
|
+
network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
|
|
475
521
|
else:
|
|
476
522
|
network.to_float(dtype)
|
|
477
523
|
network = _OutputTo32(network)
|
|
524
|
+
elif amp_level == "auto":
|
|
525
|
+
white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
|
|
526
|
+
black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
|
|
527
|
+
# set amp_strategy attribute for the object
|
|
528
|
+
amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
|
|
529
|
+
setattr(network, "amp_strategy", amp_strategy)
|
|
530
|
+
# set amp_strategy context decorator for the object
|
|
531
|
+
network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
|
|
478
532
|
else:
|
|
479
|
-
raise ValueError("The amp level {} is not supported"
|
|
533
|
+
raise ValueError(f"The amp level {amp_level} is not supported")
|
|
480
534
|
|
|
481
535
|
setattr(network, "_amp_level", amp_level)
|
|
482
536
|
|
|
@@ -516,6 +570,10 @@ _config_level = {
|
|
|
516
570
|
"O3": {
|
|
517
571
|
"keep_batchnorm_fp32": False,
|
|
518
572
|
"cast_model_type": mstype.float16,
|
|
573
|
+
"loss_scale_manager": None},
|
|
574
|
+
"auto": {
|
|
575
|
+
"keep_batchnorm_fp32": False,
|
|
576
|
+
"cast_model_type": mstype.float32,
|
|
519
577
|
"loss_scale_manager": None}}
|
|
520
578
|
|
|
521
579
|
|
|
@@ -540,20 +598,11 @@ def _check_kwargs(key_words):
|
|
|
540
598
|
def _check_level(level, boost_level):
|
|
541
599
|
"""Check level."""
|
|
542
600
|
if not isinstance(level, str):
|
|
543
|
-
raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],
|
|
544
|
-
|
|
601
|
+
raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],"
|
|
602
|
+
f"but got type {type(level)}.")
|
|
545
603
|
validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
|
|
546
604
|
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
|
|
547
605
|
|
|
548
|
-
if level == "auto":
|
|
549
|
-
device_target = context.get_context('device_target')
|
|
550
|
-
if device_target == "GPU":
|
|
551
|
-
level = "O2"
|
|
552
|
-
elif device_target == "Ascend":
|
|
553
|
-
level = "O3"
|
|
554
|
-
else:
|
|
555
|
-
raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
|
|
556
|
-
|
|
557
606
|
enable_boost = False
|
|
558
607
|
if boost_level in ["O1", "O2"]:
|
|
559
608
|
enable_boost = True
|
|
@@ -578,7 +627,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
|
|
|
578
627
|
return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
|
|
579
628
|
|
|
580
629
|
validator.check_value_type('loss_fn', loss_fn, nn.Cell)
|
|
581
|
-
if cast_model_type
|
|
630
|
+
if cast_model_type in (mstype.float16, mstype.bfloat16) or \
|
|
631
|
+
(hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
|
|
582
632
|
network = WithLossCell(network, loss_fn)
|
|
583
633
|
else:
|
|
584
634
|
network = nn.WithLossCell(network, loss_fn)
|
|
@@ -634,20 +684,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
634
684
|
Default: ``None`` .
|
|
635
685
|
level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .
|
|
636
686
|
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
- 'O2': Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
|
642
|
-
using dynamic loss scale.
|
|
643
|
-
- 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
|
|
644
|
-
- 'auto': Set to level to recommended level in different devices. Set level to 'O2' on GPU, Set
|
|
645
|
-
level to 'O3' Ascend. The recommended level is chosen by the export experience, not applicable to all
|
|
646
|
-
scenarios. User should specify the level for special network.
|
|
647
|
-
|
|
648
|
-
'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
|
|
649
|
-
`cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
|
|
650
|
-
`kwargs`.
|
|
687
|
+
For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.
|
|
688
|
+
|
|
689
|
+
Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
|
|
690
|
+
setting may be overwritten by settings in `kwargs`.
|
|
651
691
|
|
|
652
692
|
boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
|
|
653
693
|
training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
|
|
@@ -670,13 +710,13 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
670
710
|
take no effect on this property.
|
|
671
711
|
|
|
672
712
|
Raises:
|
|
673
|
-
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
|
|
674
|
-
(with property `drop_overflow_update=False` ).
|
|
713
|
+
ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
|
|
714
|
+
:class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).
|
|
675
715
|
|
|
676
716
|
Examples:
|
|
677
717
|
>>> from mindspore import amp, nn
|
|
678
718
|
>>> # Define the network structure of LeNet5. Refer to
|
|
679
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
719
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
680
720
|
>>> network = LeNet5()
|
|
681
721
|
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
|
682
722
|
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
|
|
@@ -728,7 +768,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
728
768
|
|
|
729
769
|
def get_white_list():
|
|
730
770
|
"""
|
|
731
|
-
Provide a copy of internal white list used by auto mixed precision
|
|
771
|
+
Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.
|
|
732
772
|
|
|
733
773
|
The current built-in whitelist contents are:
|
|
734
774
|
|
|
@@ -766,7 +806,7 @@ def get_white_list():
|
|
|
766
806
|
|
|
767
807
|
def get_black_list():
|
|
768
808
|
"""
|
|
769
|
-
Provide a copy of internal black list used by auto mixed precision
|
|
809
|
+
Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.
|
|
770
810
|
|
|
771
811
|
The current built-in blacklist contents are:
|
|
772
812
|
|
|
@@ -789,7 +829,6 @@ def get_black_list():
|
|
|
789
829
|
|
|
790
830
|
def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
|
|
791
831
|
"""
|
|
792
|
-
Custom mixed precision by setting whitelist or blacklist.
|
|
793
832
|
When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
|
|
794
833
|
When the `black_list` is provided, cells that are not in `black_list` will perform the pereision conversion.
|
|
795
834
|
Only one of `white_list` and `black_list` should be provided.
|
|
@@ -823,7 +862,7 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
|
|
|
823
862
|
Examples:
|
|
824
863
|
>>> from mindspore import amp, nn
|
|
825
864
|
>>> # Define the network structure of LeNet5. Refer to
|
|
826
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
865
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
827
866
|
>>> net = LeNet5()
|
|
828
867
|
>>> custom_white_list = amp.get_white_list()
|
|
829
868
|
>>> custom_white_list.append(nn.Flatten)
|
|
@@ -844,11 +883,11 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
|
|
|
844
883
|
|
|
845
884
|
if white_list is not None:
|
|
846
885
|
_list_check(white_list, "white_list")
|
|
847
|
-
network =
|
|
886
|
+
network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
|
|
848
887
|
else:
|
|
849
888
|
_list_check(black_list, "black_list")
|
|
850
889
|
if MS_AMP_BY_REWRITE:
|
|
851
|
-
network =
|
|
890
|
+
network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
|
|
852
891
|
else:
|
|
853
892
|
network = _auto_black_list(network, black_list, dtype)
|
|
854
893
|
network = _OutputTo32(network)
|
|
@@ -883,7 +922,8 @@ def _list_check(custom_list: list, list_name: str):
|
|
|
883
922
|
if elem not in custom_list:
|
|
884
923
|
logger.warning(f"{elem} is removed from internal black list.")
|
|
885
924
|
|
|
886
|
-
|
|
925
|
+
|
|
926
|
+
def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None): # pylint: disable=unused-variable
|
|
887
927
|
"""Configure auto mixed precision."""
|
|
888
928
|
global MS_AMP_BY_REWRITE
|
|
889
929
|
global _amp_cast_op
|