mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py
CHANGED
|
@@ -19,8 +19,8 @@ import mindspore as ms
|
|
|
19
19
|
from mindspore import nn
|
|
20
20
|
from mindspore import _checkparam as validator
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
|
-
from mindspore.nn.wrap.cell_wrapper import
|
|
23
|
-
from mindspore.nn.wrap.loss_scale import
|
|
22
|
+
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
|
|
23
|
+
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
|
|
24
24
|
from mindspore.ops import functional as F
|
|
25
25
|
from mindspore.parallel._utils import _get_pipeline_stages
|
|
26
26
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|
@@ -30,9 +30,6 @@ from mindspore.ops import Primitive
|
|
|
30
30
|
from mindspore import log as logger
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
STREE = None
|
|
34
|
-
|
|
35
|
-
|
|
36
33
|
AMP_WHITE_LIST = [
|
|
37
34
|
nn.Conv1d,
|
|
38
35
|
nn.Conv2d,
|
|
@@ -64,17 +61,19 @@ AMP_BLACK_LIST = [
|
|
|
64
61
|
nn.LayerNorm
|
|
65
62
|
]
|
|
66
63
|
|
|
64
|
+
MS_AMP_BY_REWRITE = False
|
|
65
|
+
_amp_cast_op = P.Cast
|
|
67
66
|
|
|
68
67
|
class _OutputTo16(nn.Cell):
|
|
69
68
|
"""Wrap cell for amp. Cast network output back to float16."""
|
|
70
|
-
def __init__(self, backbone):
|
|
69
|
+
def __init__(self, backbone, dtype=mstype.float16):
|
|
71
70
|
super(_OutputTo16, self).__init__(auto_prefix=False)
|
|
72
71
|
self._backbone = backbone
|
|
73
|
-
|
|
74
|
-
|
|
72
|
+
self.dtype = dtype
|
|
73
|
+
self._get_attr_from_cell(backbone)
|
|
75
74
|
|
|
76
|
-
def construct(self,
|
|
77
|
-
return F.cast(self._backbone(
|
|
75
|
+
def construct(self, *args, **kwargs):
|
|
76
|
+
return F.cast(self._backbone(*args, **kwargs), self.dtype)
|
|
78
77
|
|
|
79
78
|
|
|
80
79
|
class _OutputTo32(nn.Cell):
|
|
@@ -82,68 +81,78 @@ class _OutputTo32(nn.Cell):
|
|
|
82
81
|
def __init__(self, backbone):
|
|
83
82
|
super(_OutputTo32, self).__init__(auto_prefix=False)
|
|
84
83
|
self._backbone = backbone
|
|
85
|
-
|
|
86
|
-
self._jit_config_dict = backbone.jit_config_dict
|
|
84
|
+
self._get_attr_from_cell(backbone)
|
|
87
85
|
|
|
88
|
-
def construct(self, *
|
|
89
|
-
out = self._backbone(*
|
|
86
|
+
def construct(self, *args, **kwargs):
|
|
87
|
+
out = self._backbone(*args, **kwargs)
|
|
90
88
|
return F.mixed_precision_cast(mstype.float32, out)
|
|
91
89
|
|
|
92
90
|
|
|
93
|
-
|
|
91
|
+
|
|
92
|
+
def _allow_mix_precision(node, allowed_list, dtype) -> bool:
|
|
94
93
|
"""
|
|
95
94
|
Check whether current node need do mix precision. Follow conditions need to be satisfied:
|
|
96
95
|
1) Type of node is one of (Primitive, nn.Cell)
|
|
97
|
-
2) Node is not
|
|
96
|
+
2) Node is not Cast Op
|
|
98
97
|
3) to_float(mindspore.float16) is not set in Cell
|
|
99
98
|
"""
|
|
100
|
-
|
|
99
|
+
node_inst = node.get_instance()
|
|
100
|
+
if node_inst in allowed_list:
|
|
101
101
|
return True
|
|
102
|
+
if node.get_targets() is None:
|
|
103
|
+
return False
|
|
102
104
|
if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
|
|
103
105
|
return False
|
|
104
|
-
if isinstance(
|
|
106
|
+
if isinstance(node_inst, _amp_cast_op):
|
|
105
107
|
return False
|
|
106
108
|
if issubclass(node.get_instance_type(), nn.Cell):
|
|
107
|
-
# if cell is already in allowed_list, it means to_float(
|
|
108
|
-
# if cell is not in allowed_list, but has to_float(
|
|
109
|
-
# it means to_float(
|
|
110
|
-
if
|
|
109
|
+
# if cell is already in allowed_list, it means to_float() is set by amp.
|
|
110
|
+
# if cell is not in allowed_list, but has to_float(),
|
|
111
|
+
# it means to_float() is set by user.
|
|
112
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
113
|
+
if hasattr(node_inst, to_float_flag) and getattr(node_inst, to_float_flag):
|
|
111
114
|
return False
|
|
112
115
|
allowed_list.append(node.get_instance())
|
|
113
116
|
return True
|
|
114
117
|
|
|
115
118
|
|
|
116
|
-
def _insert_cast_operator_process(node,
|
|
119
|
+
def _insert_cast_operator_process(node, dtype):
|
|
117
120
|
"""insert cast for operators in white_list."""
|
|
121
|
+
dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
|
|
118
122
|
new_cast_node = None
|
|
119
|
-
|
|
123
|
+
stree = node.get_symbol_tree()
|
|
124
|
+
# insert cast fp16/bf16 before the primitive operators
|
|
120
125
|
if issubclass(node.get_instance_type(), Primitive):
|
|
121
|
-
for idx in
|
|
126
|
+
for idx, arg in enumerate(node.get_args()):
|
|
122
127
|
position = stree.before(node)
|
|
123
|
-
new_node =
|
|
124
|
-
|
|
125
|
-
|
|
128
|
+
new_node = _amp_cast_op()
|
|
129
|
+
cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
|
|
130
|
+
arg_provider = node.get_handler().get_arg_providers()[idx]
|
|
131
|
+
if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
|
|
132
|
+
cast_targets = [stree.unique_name(str(arg))]
|
|
133
|
+
else:
|
|
134
|
+
cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
|
|
126
135
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
127
|
-
targets=
|
|
128
|
-
args=
|
|
136
|
+
targets=cast_targets,
|
|
137
|
+
args=cast_args,
|
|
129
138
|
name='incast_{}{}'.format(node.get_name(), idx))
|
|
130
139
|
stree.insert(position, new_cast_node)
|
|
131
140
|
node.set_arg_by_node(idx, new_cast_node)
|
|
132
|
-
# insert cast
|
|
141
|
+
# insert cast fp16/bf16 before the Cell operators
|
|
133
142
|
elif issubclass(node.get_instance_type(), nn.Cell):
|
|
134
|
-
node.get_instance().to_float(
|
|
143
|
+
node.get_instance().to_float(dtype)
|
|
135
144
|
# ignore if subclass is not one of (Primitive, nn.Cell)
|
|
136
145
|
else:
|
|
137
146
|
return
|
|
138
147
|
|
|
139
148
|
# insert cast float32 after the operators
|
|
140
149
|
position = stree.after(node)
|
|
141
|
-
new_node =
|
|
142
|
-
|
|
143
|
-
|
|
150
|
+
new_node = _amp_cast_op()
|
|
151
|
+
cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
|
|
152
|
+
"mindspore.float32"])
|
|
144
153
|
new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
|
|
145
|
-
targets=[
|
|
146
|
-
args=
|
|
154
|
+
targets=[node.get_targets()[0]],
|
|
155
|
+
args=cast_args,
|
|
147
156
|
name='outcast_{}'.format(node.get_name()))
|
|
148
157
|
# insert node & unique names
|
|
149
158
|
stree.insert(position, new_cast_node)
|
|
@@ -156,43 +165,102 @@ def _insert_cast_operator_process(node, stree):
|
|
|
156
165
|
user.set_arg_by_node(idx, new_cast_node)
|
|
157
166
|
|
|
158
167
|
|
|
159
|
-
def _insert_cast_operator_white_list(stree, white_list):
|
|
168
|
+
def _insert_cast_operator_white_list(stree, white_list, dtype):
|
|
160
169
|
"""insert cast for operators in white_list."""
|
|
161
170
|
allowed_list = []
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
171
|
+
# Ignore if net called ".to_float(dtype)"
|
|
172
|
+
net = stree.get_handler().get_origin_network()
|
|
173
|
+
to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
|
|
174
|
+
if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
|
|
175
|
+
return
|
|
176
|
+
node_list = []
|
|
177
|
+
node_list.extend(list(stree.nodes()))
|
|
178
|
+
while node_list:
|
|
179
|
+
node = node_list.pop()
|
|
165
180
|
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
|
181
|
+
if MS_AMP_BY_REWRITE:
|
|
182
|
+
_insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
|
|
166
183
|
for n in node.get_handler().node_list:
|
|
167
184
|
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
168
185
|
_insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
|
|
169
|
-
white_list)
|
|
186
|
+
white_list, dtype)
|
|
170
187
|
elif node.get_node_type() == ms.rewrite.NodeType.Tree:
|
|
171
188
|
substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
|
|
172
|
-
_insert_cast_operator_white_list(substree, white_list)
|
|
173
|
-
elif node.
|
|
174
|
-
|
|
189
|
+
_insert_cast_operator_white_list(substree, white_list, dtype)
|
|
190
|
+
elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
|
|
191
|
+
if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
|
|
192
|
+
nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
|
|
193
|
+
node_list.extend(nodes)
|
|
194
|
+
elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
|
|
195
|
+
_insert_cast_operator_process(node, dtype)
 
 
-def
+def _insert_cast_for_cell_container(cell_container, dtype, allowed_list, *, white_list=None, black_list=None):
+    """
+    Insert cast for cell containers.
+    Only one of white_list and black_list can be set.
+    """
+
+    class CastNet(nn.Cell):
+        """Cast net"""
+        def __init__(self, dtype):
+            super().__init__()
+            self.cast = _amp_cast_op()
+            self.dtype = dtype
+
+        def construct(self, x):
+            return self.cast(x, self.dtype)
+
+    cast_flag = False
+    current_node = None
+    stree = cell_container.get_symbol_tree()
+    for node in cell_container.get_handler().nodes():
+        current_node = ms.rewrite.Node(node)
+        if (white_list is not None and current_node.get_instance_type() in white_list) or \
+                (black_list is not None and current_node.get_instance_type() not in black_list) and \
+                (_allow_mix_precision(current_node, allowed_list, dtype)):
+            cast_flag = True
+            current_node.get_instance().to_float(dtype)
+        elif cast_flag:
+            # cast next node back to float32
+            current_node.get_instance().to_float(mstype.float32)
+            cast_flag = False
+    if cast_flag and current_node:
+        # if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
+        cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
+                                                     args=[current_node.get_targets()[0]],
+                                                     targets=[current_node.get_targets()[0]],
+                                                     name=f"outcast_{cell_container.get_name()}")
+        stree.insert(stree.after(current_node), cast_node)
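The nested `CastNet` above is a one-operator cell wrapping the configured cast primitive. A standalone sketch using the public `ops.Cast` (the module-level `_amp_cast_op` is assumed to be a Cast-like primitive):

```python
from mindspore import nn, ops
import mindspore.common.dtype as mstype

class CastNet(nn.Cell):
    """One-op cell: cast its input to a fixed dtype."""
    def __init__(self, dtype):
        super().__init__()
        self.cast = ops.Cast()
        self.dtype = dtype

    def construct(self, x):
        return self.cast(x, self.dtype)

back_to_fp32 = CastNet(mstype.float32)  # appended when a container ends in low precision
```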
+
+
+def _need_removed_cast_pair(node, dtype):
     """check whether the cast pairs should be removed."""
-
+    dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
+    cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "mindspore.float32"])
     cast_dtype_f16 = cast_dtypes[0]
     cast_dtype_f32 = cast_dtypes[1]
-    # current node should be
-    if node.get_instance_type() !=
+    # current node should be Cast Op to float32
+    if node.get_instance_type() != _amp_cast_op:
         return False
     node_cast_type = node.get_args()[1]
     if node_cast_type != cast_dtype_f32:
         return False
-    # all user nodes should be
+    # all user nodes should be Cast Op to dtype or Cell with to_float(dtype)
     if not node.get_users():
         return False
+    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
     for user in node.get_users():
+        # If a ControlFlow node (if statement) exists between the current node and the user node,
+        # the cast pair should not be removed.
+        middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
+        if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
+            return False
         if isinstance(user.get_instance(), nn.Cell):
-            if
+            to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+            if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
                 return False
-        elif user.get_instance_type() ==
+        elif user.get_instance_type() == _amp_cast_op:
             user_cast_type = user.get_args()[1]
             if user_cast_type != cast_dtype_f16:
                 return False
@@ -201,18 +269,20 @@ def _need_removed_cast_pair(node):
     return True
 
 
-def _removed_cast_pair_process(
+def _removed_cast_pair_process(cast_f32_node):
     """remove the duplicated cast operators."""
-
-
-
+    stree = cast_f32_node.get_symbol_tree()
+    cast_f32_users = cast_f32_node.get_users()
+    # remove cast f16 nodes
+    for user_node in cast_f32_users:
+        if user_node.get_instance_type() == _amp_cast_op:
             cast_f16_node = user_node
             # modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
             for cast_f16_user in cast_f16_node.get_users():
                 for idx, arg in enumerate(cast_f16_user.get_args()):
                     if arg == cast_f16_node.get_targets()[0]:
                         cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
-            stree.
+            stree.erase(cast_f16_node)
         # update args of cell f16 nodes
         elif isinstance(user_node.get_instance(), nn.Cell):
             cell_f16_node = user_node
@@ -220,37 +290,81 @@ def _removed_cast_pair_process(stree, cast_f32_node):
                 if arg == cast_f32_node.get_targets()[0]:
                     cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
     # remove the cast f32 node
-    stree.
+    stree.erase(cast_f32_node)
 
 
-def _remove_duplicated_cast(stree):
+def _remove_duplicated_cast(stree, dtype):
     """remove the duplicated cast operators."""
-
-
-
+    node_list = []
+    node_list.extend(list(stree.nodes()))
+    while node_list:
+        node = node_list.pop()
         if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
             for n in node.get_handler().node_list:
                 if n.get_node_type() == ms.rewrite.NodeType.Tree:
-                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
+                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
         elif node.get_node_type() == ms.rewrite.NodeType.Tree:
             substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
-            _remove_duplicated_cast(substree)
-        elif
-
+            _remove_duplicated_cast(substree, dtype)
+        elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
+            if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
+                nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
+                node_list.extend(nodes)
+        elif _need_removed_cast_pair(node, dtype):
+            _removed_cast_pair_process(node)
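Taken together, `_need_removed_cast_pair` and `_removed_cast_pair_process` collapse the `cast(x, float32)` followed by `cast(..., dtype)` round trips that back-to-back low-precision ops would otherwise produce. A toy predicate mirroring the core check (simplified; the real code also inspects `to_float` flags and intervening ControlFlow nodes):

```python
import mindspore.common.dtype as mstype

def is_removable_pair(cast_to, user_cast_tos, dtype=mstype.float16):
    """A cast up to float32 whose every user immediately casts back to `dtype`
    is redundant: both casts can be erased and users rewired to the fp16 value."""
    return cast_to == mstype.float32 and bool(user_cast_tos) and \
        all(t == dtype for t in user_cast_tos)

print(is_removable_pair(mstype.float32, [mstype.float16, mstype.float16]))  # True
print(is_removable_pair(mstype.float32, [mstype.float32]))                  # False
```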
 
 
-def _auto_white_list(network, white_list):
+def _auto_white_list(network, white_list, dtype):
     """process the white list of network."""
-
-
-
-
-    return STREE.get_network()
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_white_list(stree, white_list, dtype)
+    _remove_duplicated_cast(stree, dtype)
+    return stree.get_network()
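`_auto_white_list` is now a three-step rewrite pipeline over the public `ms.rewrite` API. The same shape written out directly, as a sketch (`network` is any `nn.Cell`; `process` stands in for the insert/remove helpers above):

```python
import mindspore as ms

def rewrite_pipeline(network, process):
    stree = ms.rewrite.SymbolTree.create(network)  # parse the network into a symbol tree
    process(stree)                                 # insert casts, erase duplicated pairs, ...
    return stree.get_network()                     # materialize the modified network
```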
 
 
-def
+def _insert_cast_operator_black_list(stree, black_list, dtype):
+    """insert cast for operators not in black_list."""
+    allowed_list = []
+    # Ignore if net called ".to_float(dtype)"
+    net = stree.get_handler().get_origin_network()
+    to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+    if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
+        return
+    for node in stree.nodes(all_nodes=True):
+        if node.get_targets() is None:
+            continue
+        if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
+            _insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
+        elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
+            # nodes in CellContainer are processed by _insert_cast_for_cell_container
+            continue
+        elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
+            _insert_cast_operator_process(node, dtype)
+
+
+def _remove_duplicated_cast_rewrite(stree, dtype):
+    """remove the duplicated cast operators."""
+    for node in stree.nodes(all_nodes=True):
+        if _need_removed_cast_pair(node, dtype):
+            user_nodes = node.get_users()
+            # remove cast f16 nodes
+            for user_node in user_nodes:
+                if user_node.get_instance_type() == _amp_cast_op:
+                    stree.erase(user_node)
+            # remove the cast f32 node
+            stree.erase(node)
+
+
+def _auto_black_list_rewrite(network, black_list, dtype):
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_black_list(stree, black_list, dtype)
+    _remove_duplicated_cast_rewrite(stree, dtype)
+    return stree.get_network()
+
+
+def _auto_black_list(network, black_list, dtype):
     """process the black list of network."""
-    network.to_float(
+    network.to_float(dtype)
     cells = network.name_cells()
     change = False
     for name in cells:
@@ -258,32 +372,76 @@ def _auto_black_list(network, black_list):
         if subcell == network:
             continue
         if isinstance(subcell, tuple(black_list)):
-            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
             change = True
         else:
-            _auto_black_list(subcell, black_list)
+            _auto_black_list(subcell, black_list, dtype)
     if isinstance(network, nn.SequentialCell) and change:
         network.cell_list = list(network.cells())
+    return network
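`_OutputTo16` (defined earlier in amp.py, outside this hunk) keeps a black-listed subcell computing in float32 while handing its result back in the low-precision dtype, so the surrounding cast-down network keeps flowing. A sketch of such a wrapper under those assumptions, using `ops.Cast` (the real class name and internals may differ):

```python
from mindspore import nn, ops
import mindspore.common.dtype as mstype

class OutputToLowPrecision(nn.Cell):
    """Run the wrapped cell in float32, return its output in `dtype`."""
    def __init__(self, backbone, dtype=mstype.float16):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self.dtype = dtype
        self.cast = ops.Cast()

    def construct(self, *inputs):
        return self.cast(self._backbone(*inputs), self.dtype)
```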
 
 
-def auto_mixed_precision(network, amp_level="O0"):
+def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     """
-    auto mixed precision
+    Returns a network processed with auto mixed precision.
+
+    This interface will automatically perform mixed-precision processing on the input network, and the cells
+    and operators in the processed network will add precision conversion operations to calculate with lower
+    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
+    converted to lower precision float, and calculation results are converted back to full precision float,
+    i.e. ``mstype.float32`` .
+
+    The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
+    operators are specifically converted.
+
+    The current built-in whitelist contents are:
+
+    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
+    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
+    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
+    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
+    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
+    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
+    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
+
+    The current built-in blacklist contents are:
+
+    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
+    :class:`mindspore.nn.LayerNorm`]
+
+    For details on automatic mixed precision, refer to
+    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.2/advanced/mixed_precision.html>`_ .
+
+    Note:
+        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
+          can result in a larger network hierarchy and slower performance.
+        - If interfaces like `Model` and `build_train_network` are used to train a network that has been converted
+          by mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          needs to be configured to ``O0`` to avoid duplicated precision conversion.
 
     Args:
         network (Cell): Definition of the network.
-        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
+        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
 
             - "O0": Do not change.
-            - "O1":
-
-            - "
+            - "O1": Convert cells and operators in the whitelist to lower precision operations, and keep full
+              precision operations for the rest.
+            - "O2": Keep full precision operations for cells and operators in the blacklist, and convert the rest
+              to lower precision operations.
+            - "O3": Cast the network to lower precision.
+
+        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
+            default: ``mstype.float16`` .
 
     Raises:
-
+        TypeError: If `network` is not a Cell.
+        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
+        ValueError: If `amp_level` is not within the supported range.
 
     Examples:
-        >>> from mindspore import amp
+        >>> from mindspore import amp
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> network = LeNet5()
         >>> amp_level = "O1"
         >>> net = amp.auto_mixed_precision(network, amp_level)
@@ -291,18 +449,37 @@ def auto_mixed_precision(network, amp_level="O0"):
     if not isinstance(network, nn.Cell):
         raise TypeError("The network type should be Cell.")
 
+    if dtype not in (mstype.float16, mstype.bfloat16):
+        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
+
     if amp_level == "O0":
-
-
-
+        return network
+
+    # Return the network if the same amp level has already been configured
+    if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
+        logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
+                       f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
+                       f"degradation.")
+
+    if amp_level == "O1":
+        network = _auto_white_list(network, AMP_WHITE_LIST, dtype)
     elif amp_level == "O2":
-
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, AMP_BLACK_LIST, dtype)
+        else:
+            network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
+            network = _OutputTo32(network)
    elif amp_level == "O3":
-
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, [], dtype)
+        else:
+            network.to_float(dtype)
+            network = _OutputTo32(network)
     else:
         raise ValueError("The amp level {} is not supported".format(amp_level))
-
-
+
+    setattr(network, "_amp_level", amp_level)
+
     return network
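With the new `dtype` argument, the O1/O2 flows can target bfloat16 as well, hardware permitting. A usage sketch with an assumed toy network in place of the docstring's LeNet5:

```python
from mindspore import amp, nn
import mindspore.common.dtype as mstype

net = nn.SequentialCell([nn.Dense(16, 16), nn.ReLU(), nn.Dense(16, 10)])
# White-listed cells compute in bfloat16; results come back in float32
net_o1 = amp.auto_mixed_precision(net, "O1", dtype=mstype.bfloat16)
```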
 
 
@@ -393,8 +570,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
             super(WithLossCell, self).__init__(auto_prefix=False)
             self._backbone = backbone
             self._loss_fn = loss_fn
-
-            self._jit_config_dict = backbone.jit_config_dict
+            self._get_attr_from_cell(backbone)
 
         def construct(self, data, label):
             out = self._backbone(data)
@@ -409,42 +585,80 @@ def _add_loss_network(network, loss_fn, cast_model_type):
     return network
 
 
+def _is_grad_accumulation(mcell):
+    if mcell.cls_name == "GradAccumulationCell":
+        return True
+    for cell in mcell.cells():
+        if _is_grad_accumulation(cell):
+            return True
+    return False
+
+
+def _auto_mixed_precision_process(network, config, level):
+    """Auto mixed precision process."""
+    if MS_AMP_BY_REWRITE:
+        if config["cast_model_type"] == mstype.float16 or level == "O2":
+            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            # cast_model_type set by kwargs
+            level = "O0"
+        network = auto_mixed_precision(network, level)
+    else:
+        if config["cast_model_type"] == mstype.float16:
+            network.to_float(mstype.float16)
+
+            if config["keep_batchnorm_fp32"]:
+                _do_keep_batchnorm_fp32(network)
+        elif not config["keep_batchnorm_fp32"] and level == "O2":
+            network.to_float(mstype.float16)
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            pass
+        else:
+            network = auto_mixed_precision(network, level)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
 
+    Note:
+        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
+          to perform the precision conversion again. If `build_train_network` is used to train a converted network,
+          `level` needs to be configured to ``O0`` to avoid duplicated precision conversion.
+
     Args:
         network (Cell): Definition of the network.
+        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
         loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
-            Default: None.
-
-        level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
+            Default: ``None`` .
+        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .
 
-            -
-            -
+            - 'O0': Do not change.
+            - 'O1': Cast the operators in white_list to float16, the remaining operators are kept in float32.
              The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
              Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
-            -
+            - 'O2': Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
              using dynamic loss scale.
-            -
-            - auto: Set to level to recommended level in different devices. Set level to
-              level to
+            - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
+            - 'auto': Set level to the recommended level on different devices. Set level to 'O2' on GPU, set
+              level to 'O3' on Ascend. The recommended level is chosen by expert experience and is not applicable to all
              scenarios. User should specify the level for special network.
 
-
+            'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
            `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
            `kwargs`.
 
         boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
-            training. Supports [
+            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
 
-            -
-            -
+            - 'O0': Do not change.
+            - 'O1': Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
-            -
+            - 'O2': Enable the boost mode, the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.
 
-            If
+            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.
 
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be casted
@@ -461,6 +675,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
 
     Examples:
         >>> from mindspore import amp, nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
@@ -475,22 +691,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     _check_kwargs(kwargs)
     config = dict(_config_level.get(level), **kwargs)
 
-
-        network.to_float(mstype.float16)
-
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
-    elif not config["keep_batchnorm_fp32"] and level == "O2":
-        network.to_float(mstype.float16)
-    elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
-        pass
-    else:
-        network = auto_mixed_precision(network, level)
+    network = _auto_mixed_precision_process(network, config, level)
 
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
 
-    loss_scale =
+    loss_scale = None
     if config["loss_scale_manager"] is not None:
         loss_scale_manager = config["loss_scale_manager"]
         loss_scale = loss_scale_manager.get_loss_scale()
@@ -501,8 +707,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
         raise ValueError("Only `loss_scale_manager=None` or "
                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                          "are supported on device `CPU`. ")
-    if _get_pipeline_stages() > 1:
-        network =
+    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+        network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                   scale_sense=update_cell).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
@@ -511,8 +717,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
        network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                   scale_sense=update_cell).set_train()
        return network
-    if _get_pipeline_stages() > 1:
-        network =
+    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+        network = _TrainGradAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
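Since `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` can still be overridden through `kwargs`, a fixed loss scale combines with the new grad-accumulation-aware wrapping. A usage sketch (toy network and hyperparameters assumed):

```python
import mindspore as ms
from mindspore import amp, nn

net = nn.SequentialCell([nn.Dense(16, 16), nn.ReLU(), nn.Dense(16, 2)])
loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
manager = ms.FixedLossScaleManager(loss_scale=1024.0, drop_overflow_update=False)
train_net = amp.build_train_network(net, opt, loss, level="O2",
                                    loss_scale_manager=manager)
```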
@@ -524,11 +730,35 @@ def get_white_list():
     """
     Provide a copy of internal white list used by auto mixed precision.
 
-
-
+    The current built-in whitelist contents are:
+
+    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
+    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
+    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
+    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
+    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
+    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
+    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
 
     Returns:
         list, A copy of internal white list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> white_list = amp.get_white_list()
+        >>> print(white_list)
+        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
+        <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
+        <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
+        <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
+        <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
+        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
+        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
+        <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
     return white_list
@@ -538,39 +768,48 @@ def get_black_list():
     """
     Provide a copy of internal black list used by auto mixed precision.
 
-
-
+    The current built-in blacklist contents are:
+
+    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
+    :class:`mindspore.nn.LayerNorm`]
 
     Returns:
         list, A copy of internal black list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> black_list = amp.get_black_list()
+        >>> print(black_list)
+        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
+        <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
     """
     black_list = AMP_BLACK_LIST.copy()
     return black_list
 
 
-def custom_mixed_precision(network, *, white_list=None, black_list=None):
+def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
     """
     Custom mixed precision by setting whitelist or blacklist.
     When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
-    When the `black_list` is provided,
-    conversion.
+    When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
     Only one of `white_list` and `black_list` should be provided.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
     Note:
-        -
-
-
+        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
+          can result in a larger network hierarchy and slower performance.
+        - If interfaces like `Model` and `build_train_network` are used to train a network that has been converted
+          by mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          needs to be configured to ``O0`` to avoid duplicated precision conversion.
        - Primitives for black_list are not supported yet.
 
     Args:
         network (Cell): Definition of the network.
-        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: None, means
+        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
            white list is not used.
-        black_list (list[
+        black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
            black list is not used.
+        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
+            default: ``mstype.float16`` .
 
     Returns:
         network (Cell), A network supporting mixed precision.
@@ -578,13 +817,16 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
     Raises:
         TypeError: The network type is not Cell.
         ValueError: Neither `white_list` nor `black_list` is provided.
+        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
         ValueError: Both `white_list` and `black_list` are provided.
 
     Examples:
-        >>> from mindspore import amp
-        >>>
+        >>> from mindspore import amp, nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
-        >>> custom_white_list.append(nn.
+        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
     """
     if not isinstance(network, nn.Cell):
@@ -597,13 +839,19 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
         raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                          "at the same time, please provide one or the other.")
 
+    if dtype not in (mstype.float16, mstype.bfloat16):
+        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
+
     if white_list is not None:
         _list_check(white_list, "white_list")
-
-
-
-
-
+        network = _auto_white_list(network, white_list, dtype)
+    else:
+        _list_check(black_list, "black_list")
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, black_list, dtype)
+        else:
+            network = _auto_black_list(network, black_list, dtype)
+            network = _OutputTo32(network)
     return network
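A black-list usage sketch matching the new code path (toy network assumed; `nn.Dropout` is added on top of the built-in black list so it also stays in float32):

```python
from mindspore import amp, nn
import mindspore.common.dtype as mstype

net = nn.SequentialCell([nn.Dense(16, 16), nn.Dropout(p=0.5), nn.Dense(16, 2)])
custom_black_list = amp.get_black_list()
custom_black_list.append(nn.Dropout)
net = amp.custom_mixed_precision(net, black_list=custom_black_list, dtype=mstype.float16)
```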
 
 
@@ -623,11 +871,25 @@ def _list_check(custom_list: list, list_name: str):
         if not isinstance(elem, type):
             raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
 
-        if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
+        if list_name == "white_list" and not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
             raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                             f"but got {elem}")
 
+        if list_name == "black_list" and not issubclass(elem, nn.Cell):
+            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {elem}")
+
     if list_name == 'black_list':
         for elem in AMP_BLACK_LIST:
             if elem not in custom_list:
                 logger.warning(f"{elem} is removed from internal black list.")
+
+
+def _config_amp(*, enable_rewrite: bool = None, cast_op: type = None):  # pylint: disable=unused-variable
+    """Configure auto mixed precision."""
+    global MS_AMP_BY_REWRITE
+    global _amp_cast_op
+
+    if enable_rewrite is not None:
+        MS_AMP_BY_REWRITE = enable_rewrite
+
+    if cast_op is not None:
+        _amp_cast_op = cast_op
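`_config_amp` is a private module-level switch for tests and internal callers, not a public contract. A hypothetical invocation (the flag value and the choice of `ops.Cast` as the cast primitive are assumptions):

```python
from mindspore import amp, ops

# Flip AMP onto the rewrite-based path and set the cast primitive it inserts
amp._config_amp(enable_rewrite=True, cast_op=ops.Cast)
```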