mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.10-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py
CHANGED
|
@@ -19,8 +19,8 @@ import mindspore as ms
|
|
|
19
19
|
from mindspore import nn
|
|
20
20
|
from mindspore import _checkparam as validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
|
-
from mindspore.nn.wrap.cell_wrapper import
|
|
23
|
-
from mindspore.nn.wrap.loss_scale import
|
|
22
|
+
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
|
|
23
|
+
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
|
|
24
24
|
from mindspore.ops import functional as F
|
|
25
25
|
from mindspore.parallel._utils import _get_pipeline_stages
|
|
26
26
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|
@@ -30,9 +30,6 @@ from mindspore.ops import Primitive
|
|
|
30
30
|
from mindspore import log as logger
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
STREE = None
|
|
34
|
-
|
|
35
|
-
|
|
36
33
|
AMP_WHITE_LIST = [
|
|
37
34
|
nn.Conv1d,
|
|
38
35
|
nn.Conv2d,
|
|
@@ -64,17 +61,19 @@ AMP_BLACK_LIST = [
|
|
|
64
61
|
nn.LayerNorm
|
|
65
62
|
]
|
|
66
63
|
|
|
64
|
+
MS_AMP_BY_REWRITE = False
|
|
65
|
+
_amp_cast_op = P.Cast
|
|
67
66
|
|
|
68
67
|
class _OutputTo16(nn.Cell):
|
|
69
68
|
"""Wrap cell for amp. Cast network output back to float16."""
|
|
70
|
-
def __init__(self, backbone):
|
|
69
|
+
def __init__(self, backbone, dtype=mstype.float16):
|
|
71
70
|
super(_OutputTo16, self).__init__(auto_prefix=False)
|
|
72
71
|
self._backbone = backbone
|
|
73
|
-
|
|
74
|
-
|
|
72
|
+
self.dtype = dtype
|
|
73
|
+
self._get_attr_from_cell(backbone)
|
|
75
74
|
|
|
76
75
|
def construct(self, *args, **kwargs):
|
|
77
|
-
return F.cast(self._backbone(*args, **kwargs),
|
|
76
|
+
return F.cast(self._backbone(*args, **kwargs), self.dtype)
|
|
78
77
|
|
|
79
78
|
|
|
80
79
|
class _OutputTo32(nn.Cell):
|
|
@@ -82,63 +81,73 @@ class _OutputTo32(nn.Cell):
|
|
|
82
81
|
def __init__(self, backbone):
|
|
83
82
|
super(_OutputTo32, self).__init__(auto_prefix=False)
|
|
84
83
|
self._backbone = backbone
|
|
85
|
-
|
|
86
|
-
self._jit_config_dict = backbone.jit_config_dict
|
|
84
|
+
self._get_attr_from_cell(backbone)
|
|
87
85
|
|
|
88
86
|
def construct(self, *args, **kwargs):
|
|
89
87
|
out = self._backbone(*args, **kwargs)
|
|
90
88
|
return F.mixed_precision_cast(mstype.float32, out)
|
|
91
89
|
|
|
92
90
|
|
|
93
|
-
|
|
91
|
+
|
|
92
|
+
def _allow_mix_precision(node, allowed_list, dtype) -> bool:
|
|
94
93
|
"""
|
|
95
94
|
Check whether current node need do mix precision. Follow conditions need to be satisfied:
|
|
96
95
|
1) Type of node is one of (Primitive, nn.Cell)
|
|
97
|
-
2) Node is not
|
|
96
|
+
2) Node is not Cast Op
|
|
98
97
|
3) to_float(mindspore.float16) is not set in Cell
|
|
99
98
|
"""
|
|
100
|
-
|
|
99
|
+
node_inst = node.get_instance()
|
|
100
|
+
if node_inst in allowed_list:
|
|
101
101
|
return True
|
|
102
|
+
if node.get_targets() is None:
|
|
103
|
+
return False
|
|
102
104
|
if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
|
|
103
105
|
return False
|
|
104
|
-
if isinstance(
|
|
106
|
+
if isinstance(node_inst, _amp_cast_op):
|
|
105
107
|
return False
|
|
106
108
|
if issubclass(node.get_instance_type(), nn.Cell):
|
|
107
|
-
# if cell is already in allowed_list, it means to_float(
|
|
108
|
-
# if cell is not in allowed_list, but has to_float(
|
|
109
|
-
# it means to_float(
|
|
110
|
-
|
|
109
|
+
# if cell is already in allowed_list, it means to_float() is set by amp.
|
|
110
|
+
# if cell is not in allowed_list, but has to_float(),
|
|
111
|
+
# it means to_float() is set by user.
|
|
112
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
113
|
+
if hasattr(node_inst, to_float_flag) and getattr(node_inst, to_float_flag):
|
|
111
114
|
return False
|
|
112
115
|
allowed_list.append(node.get_instance())
|
|
113
116
|
return True
|
|
114
117
|
|
|
115
118
|
|
|
116
|
-
def _insert_cast_operator_process(node,
|
|
119
|
+
def _insert_cast_operator_process(node, dtype):
|
|
117
120
|
"""insert cast for operators in white_list."""
|
|
121
|
+
dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
|
|
118
122
|
new_cast_node = None
|
|
119
|
-
|
|
123
|
+
stree = node.get_symbol_tree()
|
|
124
|
+
# insert cast fp16/bf16 before the primitive operators
|
|
120
125
|
if issubclass(node.get_instance_type(), Primitive):
|
|
121
126
|
for idx, arg in enumerate(node.get_args()):
|
|
122
127
|
position = stree.before(node)
|
|
123
|
-
new_node =
|
|
124
|
-
cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value,
|
|
125
|
-
|
|
128
|
+
new_node = _amp_cast_op()
|
|
129
|
+
cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
|
|
130
|
+
arg_provider = node.get_handler().get_arg_providers()[idx]
|
|
131
|
+
if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
|
|
132
|
+
cast_targets = [stree.unique_name(str(arg))]
|
|
133
|
+
else:
|
|
134
|
+
cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
|
|
126
135
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
127
136
|
targets=cast_targets,
|
|
128
137
|
args=cast_args,
|
|
129
138
|
name='incast_{}{}'.format(node.get_name(), idx))
|
|
130
139
|
stree.insert(position, new_cast_node)
|
|
131
140
|
node.set_arg_by_node(idx, new_cast_node)
|
|
132
|
-
# insert cast
|
|
141
|
+
# insert cast fp16/bf16 before the Cell operators
|
|
133
142
|
elif issubclass(node.get_instance_type(), nn.Cell):
|
|
134
|
-
node.get_instance().to_float(
|
|
143
|
+
node.get_instance().to_float(dtype)
|
|
135
144
|
# ignore if subclass is not one of (Primitive, nn.Cell)
|
|
136
145
|
else:
|
|
137
146
|
return
|
|
138
147
|
|
|
139
148
|
# insert cast float32 after the operators
|
|
140
149
|
position = stree.after(node)
|
|
141
|
-
new_node =
|
|
150
|
+
new_node = _amp_cast_op()
|
|
142
151
|
cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
143
152
|
"mindspore.float32"])
|
|
144
153
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
@@ -156,49 +165,102 @@ def _insert_cast_operator_process(node, stree):
|
|
|
156
165
|
user.set_arg_by_node(idx, new_cast_node)
|
|
157
166
|
|
|
158
167
|
|
|
159
|
-
def _insert_cast_operator_white_list(stree, white_list):
|
|
168
|
+
def _insert_cast_operator_white_list(stree, white_list, dtype):
|
|
160
169
|
"""insert cast for operators in white_list."""
|
|
161
170
|
allowed_list = []
|
|
162
|
-
# Ignore if net called ".to_float(
|
|
171
|
+
# Ignore if net called ".to_float(dtype)"
|
|
163
172
|
net = stree.get_handler().get_origin_network()
|
|
164
|
-
if
|
|
173
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
174
|
+
if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
|
|
165
175
|
return
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
176
|
+
node_list = []
|
|
177
|
+
node_list.extend(list(stree.nodes()))
|
|
178
|
+
while node_list:
|
|
179
|
+
node = node_list.pop()
|
|
169
180
|
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
181
|
+
if MS_AMP_BY_REWRITE:
|
|
182
|
+
_insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
|
|
170
183
|
for n in node.get_handler().node_list:
|
|
171
184
|
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
172
185
|
_insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
|
|
173
|
-
white_list)
|
|
186
|
+
white_list, dtype)
|
|
174
187
|
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
175
188
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
176
|
-
_insert_cast_operator_white_list(substree, white_list)
|
|
177
|
-
elif node.
|
|
178
|
-
|
|
189
|
+
_insert_cast_operator_white_list(substree, white_list, dtype)
|
|
190
|
+
elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
|
|
191
|
+
if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
|
|
192
|
+
nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
|
|
193
|
+
node_list.extend(nodes)
|
|
194
|
+
elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
|
|
195
|
+
_insert_cast_operator_process(node, dtype)
|
|
179
196
|
|
|
180
197
|
|
|
181
|
-
def
|
|
198
|
+
def _insert_cast_for_cell_container(cell_container, dtype, allowed_list, *, white_list=None, black_list=None):
|
|
199
|
+
"""
|
|
200
|
+
Insert cast for cell containers.
|
|
201
|
+
Only one of white_list and black_list can be set.
|
|
202
|
+
"""
|
|
203
|
+
|
|
204
|
+
class CastNet(nn.Cell):
|
|
205
|
+
"""Cast net"""
|
|
206
|
+
def __init__(self, dtype):
|
|
207
|
+
super().__init__()
|
|
208
|
+
self.cast = _amp_cast_op()
|
|
209
|
+
self.dtype = dtype
|
|
210
|
+
|
|
211
|
+
def construct(self, x):
|
|
212
|
+
return self.cast(x, self.dtype)
|
|
213
|
+
|
|
214
|
+
cast_flag = False
|
|
215
|
+
current_node = None
|
|
216
|
+
stree = cell_container.get_symbol_tree()
|
|
217
|
+
for node in cell_container.get_handler().nodes():
|
|
218
|
+
current_node = ms.rewrite.Node(node)
|
|
219
|
+
if (white_list is not None and current_node.get_instance_type() in white_list) or \
|
|
220
|
+
(black_list is not None and current_node.get_instance_type() not in black_list) and \
|
|
221
|
+
(_allow_mix_precision(current_node, allowed_list, dtype)):
|
|
222
|
+
cast_flag = True
|
|
223
|
+
current_node.get_instance().to_float(dtype)
|
|
224
|
+
elif cast_flag:
|
|
225
|
+
# cast next node back to float32
|
|
226
|
+
current_node.get_instance().to_float(mstype.float32)
|
|
227
|
+
cast_flag = False
|
|
228
|
+
if cast_flag and current_node:
|
|
229
|
+
# if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
|
|
230
|
+
cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
|
|
231
|
+
args=[current_node.get_targets()[0]],
|
|
232
|
+
targets=[current_node.get_targets()[0]],
|
|
233
|
+
name=f"outcast_{cell_container.get_name()}")
|
|
234
|
+
stree.insert(stree.after(current_node), cast_node)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _need_removed_cast_pair(node, dtype):
|
|
182
238
|
"""check whether the cast pairs should be removed."""
|
|
183
|
-
|
|
239
|
+
dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
|
|
240
|
+
cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "mindspore.float32"])
|
|
184
241
|
cast_dtype_f16 = cast_dtypes[0]
|
|
185
242
|
cast_dtype_f32 = cast_dtypes[1]
|
|
186
|
-
# current node should be
|
|
187
|
-
if node.get_instance_type() !=
|
|
243
|
+
# current node should be Cast Op to float32
|
|
244
|
+
if node.get_instance_type() != _amp_cast_op:
|
|
188
245
|
return False
|
|
189
246
|
node_cast_type = node.get_args()[1]
|
|
190
247
|
if node_cast_type != cast_dtype_f32:
|
|
191
248
|
return False
|
|
192
|
-
# all user nodes should be
|
|
249
|
+
# all user nodes should be Cast Op to dtype or Cell with to_float(dtype)
|
|
193
250
|
if not node.get_users():
|
|
194
251
|
return False
|
|
252
|
+
all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
|
|
195
253
|
for user in node.get_users():
|
|
254
|
+
# If ControlFlow node(if statement) exists between current node and user node,
|
|
255
|
+
# cast pair should not be removed.
|
|
256
|
+
middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
|
|
257
|
+
if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
|
|
258
|
+
return False
|
|
196
259
|
if isinstance(user.get_instance(), nn.Cell):
|
|
197
|
-
if
|
|
198
|
-
|
|
199
|
-
if not user.get_instance().to_float_fp16:
|
|
260
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
261
|
+
if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
|
|
200
262
|
return False
|
|
201
|
-
elif user.get_instance_type() ==
|
|
263
|
+
elif user.get_instance_type() == _amp_cast_op:
|
|
202
264
|
user_cast_type = user.get_args()[1]
|
|
203
265
|
if user_cast_type != cast_dtype_f16:
|
|
204
266
|
return False
|
|
@@ -207,11 +269,13 @@ def _need_removed_cast_pair(node):
|
|
|
207
269
|
return True
|
|
208
270
|
|
|
209
271
|
|
|
210
|
-
def _removed_cast_pair_process(
|
|
272
|
+
def _removed_cast_pair_process(cast_f32_node):
|
|
211
273
|
"""remove the duplicated cast operators."""
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
274
|
+
stree = cast_f32_node.get_symbol_tree()
|
|
275
|
+
cast_f32_users = cast_f32_node.get_users()
|
|
276
|
+
# remove cast f16 nodes
|
|
277
|
+
for user_node in cast_f32_users:
|
|
278
|
+
if user_node.get_instance_type() == _amp_cast_op:
|
|
215
279
|
cast_f16_node = user_node
|
|
216
280
|
# modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
|
|
217
281
|
for cast_f16_user in cast_f16_node.get_users():
|
|
@@ -229,34 +293,78 @@ def _removed_cast_pair_process(stree, cast_f32_node):
|
|
|
229
293
|
stree.erase(cast_f32_node)
|
|
230
294
|
|
|
231
295
|
|
|
232
|
-
def _remove_duplicated_cast(stree):
|
|
296
|
+
def _remove_duplicated_cast(stree, dtype):
|
|
233
297
|
"""remove the duplicated cast operators."""
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
298
|
+
node_list = []
|
|
299
|
+
node_list.extend(list(stree.nodes()))
|
|
300
|
+
while node_list:
|
|
301
|
+
node = node_list.pop()
|
|
237
302
|
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
238
303
|
for n in node.get_handler().node_list:
|
|
239
304
|
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
240
|
-
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
|
|
305
|
+
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
|
|
241
306
|
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
242
307
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
243
|
-
_remove_duplicated_cast(substree)
|
|
244
|
-
elif
|
|
245
|
-
|
|
308
|
+
_remove_duplicated_cast(substree, dtype)
|
|
309
|
+
elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
|
|
310
|
+
if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
|
|
311
|
+
nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
|
|
312
|
+
node_list.extend(nodes)
|
|
313
|
+
elif _need_removed_cast_pair(node, dtype):
|
|
314
|
+
_removed_cast_pair_process(node)
|
|
246
315
|
|
|
247
316
|
|
|
248
|
-
def _auto_white_list(network, white_list):
|
|
317
|
+
def _auto_white_list(network, white_list, dtype):
|
|
249
318
|
"""process the white list of network."""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
return STREE.get_network()
|
|
319
|
+
stree = ms.rewrite.SymbolTree.create(network)
|
|
320
|
+
_insert_cast_operator_white_list(stree, white_list, dtype)
|
|
321
|
+
_remove_duplicated_cast(stree, dtype)
|
|
322
|
+
return stree.get_network()
|
|
255
323
|
|
|
256
324
|
|
|
257
|
-
def
|
|
325
|
+
def _insert_cast_operator_black_list(stree, black_list, dtype):
|
|
326
|
+
"""insert cast for operators not in black_list."""
|
|
327
|
+
allowed_list = []
|
|
328
|
+
# Ignore if net called ".to_float(dtype)"
|
|
329
|
+
net = stree.get_handler().get_origin_network()
|
|
330
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
331
|
+
if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
|
|
332
|
+
return
|
|
333
|
+
for node in stree.nodes(all_nodes=True):
|
|
334
|
+
if node.get_targets() is None:
|
|
335
|
+
continue
|
|
336
|
+
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
337
|
+
_insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
|
|
338
|
+
elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
|
|
339
|
+
# nodes in CellContainer are processed by _insert_cast_for_cell_container
|
|
340
|
+
continue
|
|
341
|
+
elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
|
|
342
|
+
_insert_cast_operator_process(node, dtype)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _remove_duplicated_cast_rewrite(stree, dtype):
|
|
346
|
+
"""remove the duplicated cast operators."""
|
|
347
|
+
for node in stree.nodes(all_nodes=True):
|
|
348
|
+
if _need_removed_cast_pair(node, dtype):
|
|
349
|
+
user_nodes = node.get_users()
|
|
350
|
+
# remove cast f16 nodes
|
|
351
|
+
for user_node in user_nodes:
|
|
352
|
+
if user_node.get_instance_type() == _amp_cast_op:
|
|
353
|
+
stree.erase(user_node)
|
|
354
|
+
# remove the cast f32 node
|
|
355
|
+
stree.erase(node)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _auto_black_list_rewrite(network, black_list, dtype):
|
|
359
|
+
stree = ms.rewrite.SymbolTree.create(network)
|
|
360
|
+
_insert_cast_operator_black_list(stree, black_list, dtype)
|
|
361
|
+
_remove_duplicated_cast_rewrite(stree, dtype)
|
|
362
|
+
return stree.get_network()
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _auto_black_list(network, black_list, dtype):
|
|
258
366
|
"""process the black list of network."""
|
|
259
|
-
network.to_float(
|
|
367
|
+
network.to_float(dtype)
|
|
260
368
|
cells = network.name_cells()
|
|
261
369
|
change = False
|
|
262
370
|
for name in cells:
|
|
@@ -264,30 +372,27 @@ def _auto_black_list(network, black_list):
|
|
|
264
372
|
if subcell == network:
|
|
265
373
|
continue
|
|
266
374
|
if isinstance(subcell, tuple(black_list)):
|
|
267
|
-
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
|
|
375
|
+
network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
|
|
268
376
|
change = True
|
|
269
377
|
else:
|
|
270
|
-
_auto_black_list(subcell, black_list)
|
|
378
|
+
_auto_black_list(subcell, black_list, dtype)
|
|
271
379
|
if isinstance(network, nn.SequentialCell) and change:
|
|
272
380
|
network.cell_list = list(network.cells())
|
|
381
|
+
return network
|
|
273
382
|
|
|
274
383
|
|
|
275
|
-
def auto_mixed_precision(network, amp_level="O0"):
|
|
384
|
+
def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
276
385
|
"""
|
|
277
386
|
Returns a network processed with auto mixed precision.
|
|
278
387
|
|
|
279
388
|
This interface will automatically perform mixed-precision processing on the input network, and the cells
|
|
280
|
-
and operators in the processed network will add precision conversion operations to calculate with
|
|
281
|
-
Inputs and parameters of cells and operators are
|
|
282
|
-
back to
|
|
389
|
+
and operators in the processed network will add precision conversion operations to calculate with lower
|
|
390
|
+
precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
|
|
391
|
+
converted to lower precision float, and calculation results are converted back to full precision float,
|
|
392
|
+
i.e. ``mstype.float32`` .
|
|
283
393
|
|
|
284
394
|
The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
|
|
285
|
-
operators are specifically converted
|
|
286
|
-
|
|
287
|
-
- When `amp_level="O0"` , no precision conversion is performed.
|
|
288
|
-
- When `amp_level="O1"` , only the cells and operators in the whitelist will be converted.
|
|
289
|
-
- When `amp_level="O2"` , all cells and operators except those in the blacklist will be converted.
|
|
290
|
-
- When `amp_level="O3"` , all cells and operators in the network are converted.
|
|
395
|
+
operators are specifically converted.
|
|
291
396
|
|
|
292
397
|
The current built-in whitelist contents are:
|
|
293
398
|
|
|
@@ -305,26 +410,38 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|
|
305
410
|
:class:`mindspore.nn.LayerNorm`]
|
|
306
411
|
|
|
307
412
|
For details on automatic mixed precision, refer to
|
|
308
|
-
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.
|
|
413
|
+
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.2/advanced/mixed_precision.html>`_ .
|
|
414
|
+
|
|
415
|
+
Note:
|
|
416
|
+
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
|
|
417
|
+
can result in a larger network hierarchy and slower performance.
|
|
418
|
+
- If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
|
|
419
|
+
mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
|
|
420
|
+
need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
|
|
309
421
|
|
|
310
422
|
Args:
|
|
311
423
|
network (Cell): Definition of the network.
|
|
312
424
|
amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
|
|
313
425
|
|
|
314
426
|
- "O0": Do not change.
|
|
315
|
-
- "O1": Convert cells and operators in whitelist to
|
|
427
|
+
- "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
|
|
316
428
|
precision operations for the rest.
|
|
317
|
-
- "O2": Keep
|
|
318
|
-
to
|
|
319
|
-
- "O3": Cast network to
|
|
429
|
+
- "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
|
|
430
|
+
to lower precision operations.
|
|
431
|
+
- "O3": Cast network to lower precision.
|
|
432
|
+
|
|
433
|
+
dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
|
|
434
|
+
default: ``mstype.float16`` .
|
|
320
435
|
|
|
321
436
|
Raises:
|
|
322
|
-
|
|
437
|
+
TypeError: If `network` is not a Cell.
|
|
438
|
+
ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
|
|
439
|
+
ValueError: If `amp_level` is not within the supported range.
|
|
323
440
|
|
|
324
441
|
Examples:
|
|
325
442
|
>>> from mindspore import amp
|
|
326
443
|
>>> # Define the network structure of LeNet5. Refer to
|
|
327
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
444
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
328
445
|
>>> network = LeNet5()
|
|
329
446
|
>>> amp_level = "O1"
|
|
330
447
|
>>> net = amp.auto_mixed_precision(network, amp_level)
|
|
@@ -332,20 +449,37 @@ def auto_mixed_precision(network, amp_level="O0"):
|
|
|
332
449
|
if not isinstance(network, nn.Cell):
|
|
333
450
|
raise TypeError("The network type should be Cell.")
|
|
334
451
|
|
|
452
|
+
if dtype not in (mstype.float16, mstype.bfloat16):
|
|
453
|
+
raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
|
|
454
|
+
|
|
335
455
|
if amp_level == "O0":
|
|
336
456
|
return network
|
|
337
457
|
|
|
338
|
-
if
|
|
339
|
-
|
|
458
|
+
# Return network if the same amp level has already been configurated
|
|
459
|
+
if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
|
|
460
|
+
logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
|
|
461
|
+
f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
|
|
462
|
+
f"degradation.")
|
|
340
463
|
|
|
341
|
-
if amp_level == "
|
|
342
|
-
|
|
464
|
+
if amp_level == "O1":
|
|
465
|
+
network = _auto_white_list(network, AMP_WHITE_LIST, dtype)
|
|
466
|
+
elif amp_level == "O2":
|
|
467
|
+
if MS_AMP_BY_REWRITE:
|
|
468
|
+
network = _auto_black_list_rewrite(network, AMP_BLACK_LIST, dtype)
|
|
469
|
+
else:
|
|
470
|
+
network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
|
|
471
|
+
network = _OutputTo32(network)
|
|
343
472
|
elif amp_level == "O3":
|
|
344
|
-
|
|
473
|
+
if MS_AMP_BY_REWRITE:
|
|
474
|
+
network = _auto_black_list_rewrite(network, [], dtype)
|
|
475
|
+
else:
|
|
476
|
+
network.to_float(dtype)
|
|
477
|
+
network = _OutputTo32(network)
|
|
345
478
|
else:
|
|
346
479
|
raise ValueError("The amp level {} is not supported".format(amp_level))
|
|
347
|
-
|
|
348
|
-
|
|
480
|
+
|
|
481
|
+
setattr(network, "_amp_level", amp_level)
|
|
482
|
+
|
|
349
483
|
return network
|
|
350
484
|
|
|
351
485
|
|
|
@@ -436,8 +570,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
             super(WithLossCell, self).__init__(auto_prefix=False)
             self._backbone = backbone
             self._loss_fn = loss_fn
-
-            self._jit_config_dict = backbone.jit_config_dict
+            self._get_attr_from_cell(backbone)
 
         def construct(self, data, label):
             out = self._backbone(data)
@@ -452,6 +585,39 @@ def _add_loss_network(network, loss_fn, cast_model_type):
     return network
 
 
+def _is_grad_accumulation(mcell):
+    if mcell.cls_name == "GradAccumulationCell":
+        return True
+    for cell in mcell.cells():
+        if _is_grad_accumulation(cell):
+            return True
+    return False
+
+
+def _auto_mixed_precision_process(network, config, level):
+    """Auto mixed precision process."""
+    if MS_AMP_BY_REWRITE:
+        if config["cast_model_type"] == mstype.float16 or level == "O2":
+            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            # cast_model_type set by kwargs
+            level = "O0"
+        network = auto_mixed_precision(network, level)
+    else:
+        if config["cast_model_type"] == mstype.float16:
+            network.to_float(mstype.float16)
+
+            if config["keep_batchnorm_fp32"]:
+                _do_keep_batchnorm_fp32(network)
+        elif not config["keep_batchnorm_fp32"] and level == "O2":
+            network.to_float(mstype.float16)
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            pass
+        else:
+            network = auto_mixed_precision(network, level)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
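The newly added `_is_grad_accumulation` above is a plain depth-first scan over the cell tree. A standalone illustration of the same traversal follows; the `MockCell` class is a hypothetical stand-in, mirroring only the two `nn.Cell` members the helper relies on (`cls_name` and `cells()`):

    # Stand-in objects mimicking the nn.Cell members the helper touches:
    # `cls_name` (the class name string) and `cells()` (immediate child cells).
    class MockCell:
        def __init__(self, cls_name, children=()):
            self.cls_name = cls_name
            self._children = list(children)

        def cells(self):
            return self._children

    def is_grad_accumulation(mcell):
        # Same logic as _is_grad_accumulation above: match the current cell,
        # then recurse into each child until a GradAccumulationCell is found.
        if mcell.cls_name == "GradAccumulationCell":
            return True
        return any(is_grad_accumulation(cell) for cell in mcell.cells())

    wrapped = MockCell("TrainCell", [MockCell("GradAccumulationCell")])
    plain = MockCell("TrainCell", [MockCell("Dense")])
    assert is_grad_accumulation(wrapped) and not is_grad_accumulation(plain)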
@@ -510,7 +676,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     Examples:
         >>> from mindspore import amp, nn
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> network = LeNet5()
         >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
         >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
@@ -525,17 +691,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     _check_kwargs(kwargs)
     config = dict(_config_level.get(level), **kwargs)
 
-    if config["cast_model_type"] == mstype.float16:
-        network.to_float(mstype.float16)
-
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
-    elif not config["keep_batchnorm_fp32"] and level == "O2":
-        network.to_float(mstype.float16)
-    elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
-        pass
-    else:
-        network = auto_mixed_precision(network, level)
+    network = _auto_mixed_precision_process(network, config, level)
 
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
@@ -551,8 +707,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
             raise ValueError("Only `loss_scale_manager=None` or "
                              "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                              "are supported on device `CPU`. ")
-        if _get_pipeline_stages() > 1:
-            network =
+        if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+            network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                       scale_sense=update_cell).set_train()
         elif enable_boost:
             network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
@@ -561,8 +717,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
             network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                        scale_sense=update_cell).set_train()
         return network
-    if _get_pipeline_stages() > 1:
-        network =
+    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+        network = _TrainGradAccuStepCell(network, optimizer).set_train()
     elif enable_boost:
         network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
     else:
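Putting the rewritten branches together, a hedged end-to-end sketch of the loss-scale path follows. It assumes MindSpore 2.2 and a `LeNet5` definition as in the docstring example; `FixedLossScaleManager` is the top-level `mindspore` export:

    # Sketch only: exercises build_train_network's loss-scale handling.
    # LeNet5 is assumed to be defined as in the lenet.py link above.
    import mindspore as ms
    from mindspore import amp, nn

    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
    manager = ms.FixedLossScaleManager(drop_overflow_update=False)
    train_net = amp.build_train_network(network, net_opt, net_loss, level="O2",
                                        loss_scale_manager=manager)
    # With drop_overflow_update=False there is no update cell, so the call above
    # takes the plain TrainOneStepCell branch rather than a loss-scale cell.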
@@ -586,6 +742,23 @@ def get_white_list():
 
     Returns:
         list, A copy of internal white list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> white_list = amp.get_white_list()
+        >>> print(white_list)
+        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
+        <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
+        <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
+        <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
+        <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
+        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
+        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
+        <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
     return white_list
@@ -602,24 +775,31 @@ def get_black_list():
 
     Returns:
         list, A copy of internal black list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> black_list = amp.get_black_list()
+        >>> print(black_list)
+        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
+        <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
     """
     black_list = AMP_BLACK_LIST.copy()
     return black_list
 
 
-def custom_mixed_precision(network, *, white_list=None, black_list=None):
+def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
     """
     Custom mixed precision by setting whitelist or blacklist.
     When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
-    When the `black_list` is provided, cells that are not in `black_list` will perform the precision
-    conversion.
+    When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
     Only one of `white_list` and `black_list` should be provided.
 
     Note:
-        -
-
-
-
+        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
+          can result in a larger network hierarchy and slower performance.
+        - If interfaces like `Model` and `build_train_network` are used to train the network which is converted by
+          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          needs to be configured to ``O0`` to avoid duplicated precision conversion.
         - Primitives for blacklist are not supported yet.
 
     Args:
@@ -628,6 +808,8 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
             white list is not used.
         black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
             black list is not used.
+        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
+            default: ``mstype.float16`` .
 
     Returns:
         network (Cell), A network supporting mixed precision.
@@ -635,12 +817,13 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
     Raises:
         TypeError: The network type is not Cell.
         ValueError: Neither `white_list` nor `black_list` is provided.
+        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
         ValueError: Both `white_list` and `black_list` are provided.
 
     Examples:
         >>> from mindspore import amp, nn
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
         >>> custom_white_list = amp.get_white_list()
         >>> custom_white_list.append(nn.Flatten)
@@ -656,13 +839,19 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
         raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                          "at the same time, please provide one or the other.")
 
+    if dtype not in (mstype.float16, mstype.bfloat16):
+        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
+
     if white_list is not None:
         _list_check(white_list, "white_list")
-
-
-
-
-
+        network = _auto_white_list(network, white_list, dtype)
+    else:
+        _list_check(black_list, "black_list")
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, black_list, dtype)
+        else:
+            network = _auto_black_list(network, black_list, dtype)
+            network = _OutputTo32(network)
     return network
 
 
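As with `auto_mixed_precision`, the new `dtype` parameter selects the low-precision type for the converted layers. A minimal sketch, assuming MindSpore 2.2; the small `SequentialCell` stands in for a real model:

    # Sketch of custom_mixed_precision with the new dtype parameter.
    from mindspore import amp, nn
    from mindspore.common import dtype as mstype

    net = nn.SequentialCell(nn.Dense(16, 8), nn.ReLU(), nn.Dense(8, 2))
    custom_black_list = amp.get_black_list()      # BatchNorm/LayerNorm cells
    net = amp.custom_mixed_precision(net, black_list=custom_black_list,
                                     dtype=mstype.bfloat16)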
@@ -693,3 +882,14 @@ def _list_check(custom_list: list, list_name: str):
     for elem in AMP_BLACK_LIST:
         if elem not in custom_list:
             logger.warning(f"{elem} is removed from internal black list.")
+
+def _config_amp(*, enable_rewrite: bool = None, cast_op: type = None):  # pylint: disable=unused-variable
+    """Configure auto mixed precision."""
+    global MS_AMP_BY_REWRITE
+    global _amp_cast_op
+
+    if enable_rewrite is not None:
+        MS_AMP_BY_REWRITE = enable_rewrite
+
+    if cast_op is not None:
+        _amp_cast_op = cast_op