mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the package registry's advisory page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -95,7 +95,12 @@ class WithLossCell(Cell):
|
|
|
95
95
|
``Ascend`` ``GPU`` ``CPU``
|
|
96
96
|
|
|
97
97
|
Examples:
|
|
98
|
-
>>>
|
|
98
|
+
>>> import mindspore as ms
|
|
99
|
+
>>> from mindspore import Tensor, nn
|
|
100
|
+
>>> import numpy as np
|
|
101
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
102
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
103
|
+
>>> net = LeNet5()
|
|
99
104
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
|
100
105
|
>>> net_with_criterion = nn.WithLossCell(net, loss_fn)
|
|
101
106
|
>>>
|
|
@@ -110,8 +115,7 @@ class WithLossCell(Cell):
|
|
|
110
115
|
super(WithLossCell, self).__init__(auto_prefix=False)
|
|
111
116
|
self._backbone = backbone
|
|
112
117
|
self._loss_fn = loss_fn
|
|
113
|
-
|
|
114
|
-
self._jit_config_dict = backbone.jit_config_dict
|
|
118
|
+
self._get_attr_from_cell(backbone)
|
|
115
119
|
|
|
116
120
|
def construct(self, data, label):
|
|
117
121
|
out = self._backbone(data)
|
|
@@ -124,6 +128,15 @@ class WithLossCell(Cell):
|
|
|
124
128
|
|
|
125
129
|
Returns:
|
|
126
130
|
Cell, the backbone network.
|
|
131
|
+
|
|
132
|
+
Examples:
|
|
133
|
+
>>> from mindspore import nn
|
|
134
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
135
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
136
|
+
>>> net = LeNet5()
|
|
137
|
+
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
|
138
|
+
>>> net_with_criterion = nn.WithLossCell(net, loss_fn)
|
|
139
|
+
>>> backbone = net_with_criterion.backbone_network
|
|
127
140
|
"""
|
|
128
141
|
return self._backbone
|
|
129
142
|
|
|
@@ -141,10 +154,10 @@ class WithGradCell(Cell):
|
|
|
141
154
|
|
|
142
155
|
Args:
|
|
143
156
|
network (Cell): The target network to wrap. The network only supports single output.
|
|
144
|
-
loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
|
|
157
|
+
loss_fn (Cell): Primitive loss function used to compute gradients. Default: ``None`` .
|
|
145
158
|
sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitive for backpropagation, the type and shape
|
|
146
|
-
must be same as the `network` output. If None, we will fill one to a same type shape of
|
|
147
|
-
output value. Default: None.
|
|
159
|
+
must be same as the `network` output. If ``None`` , we will fill one to a same type shape of
|
|
160
|
+
output value. Default: ``None`` .
|
|
148
161
|
|
|
149
162
|
Inputs:
|
|
150
163
|
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
@@ -159,8 +172,11 @@ class WithGradCell(Cell):
|
|
|
159
172
|
``Ascend`` ``GPU`` ``CPU``
|
|
160
173
|
|
|
161
174
|
Examples:
|
|
162
|
-
>>>
|
|
163
|
-
>>>
|
|
175
|
+
>>> import mindspore as ms
|
|
176
|
+
>>> from mindspore import nn
|
|
177
|
+
>>> # Defined a network without loss function, taking LeNet5 as an example.
|
|
178
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
179
|
+
>>> net = LeNet5()
|
|
164
180
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
165
181
|
>>> grad_net = nn.WithGradCell(net, loss_fn)
|
|
166
182
|
>>>
|
|
@@ -182,8 +198,7 @@ class WithGradCell(Cell):
|
|
|
182
198
|
else:
|
|
183
199
|
self.network_with_loss = WithLossCell(self.network, self.loss_fn)
|
|
184
200
|
self.network_with_loss.set_train()
|
|
185
|
-
|
|
186
|
-
self._jit_config_dict = network.jit_config_dict
|
|
201
|
+
self._get_attr_from_cell(network)
|
|
187
202
|
|
|
188
203
|
def construct(self, *inputs):
|
|
189
204
|
weights = self.weights
|
|
@@ -202,20 +217,20 @@ class ForwardValueAndGrad(Cell):
|
|
|
202
217
|
The backward graph will be created in the gradient function to calculating gradient.
|
|
203
218
|
|
|
204
219
|
Args:
|
|
205
|
-
network (Cell): The training network.
|
|
220
|
+
network (Union[Cell, Function, MethodType]): The training network.
|
|
206
221
|
weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
|
|
207
|
-
Default: None.
|
|
208
|
-
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
|
|
209
|
-
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
|
|
210
|
-
If get_all and get_by_list are both False, get the gradient with respect to first input.
|
|
211
|
-
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter
|
|
212
|
-
at the same time in the form of ((gradients with respect to inputs),
|
|
213
|
-
(gradients with respect to parameters)). Default: False.
|
|
222
|
+
Default: ``None`` .
|
|
223
|
+
get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
|
|
224
|
+
get_by_list (bool): If ``True`` s, get all the gradients with respect to Parameter variables.
|
|
225
|
+
If get_all and get_by_list are both ``False`` , get the gradient with respect to first input.
|
|
226
|
+
If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and Parameter
|
|
227
|
+
variables at the same time in the form of ((gradients with respect to inputs),
|
|
228
|
+
(gradients with respect to parameters)). Default: ``False`` .
|
|
214
229
|
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
|
|
215
|
-
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
|
|
216
|
-
Default: False.
|
|
217
|
-
If the sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred
|
|
218
|
-
the input parameter.
|
|
230
|
+
If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically.
|
|
231
|
+
Default: ``False`` .
|
|
232
|
+
If the sens_param is ``True`` , a sensitivity (gradient with respect to output) needs to be transferred
|
|
233
|
+
through the input parameter.
|
|
219
234
|
|
|
220
235
|
Inputs:
|
|
221
236
|
- **\*inputs** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
|
|
@@ -232,7 +247,8 @@ class ForwardValueAndGrad(Cell):
|
|
|
232
247
|
|
|
233
248
|
Examples:
|
|
234
249
|
>>> import numpy as np
|
|
235
|
-
>>>
|
|
250
|
+
>>> import mindspore
|
|
251
|
+
>>> from mindspore import Tensor, nn, ops, ParameterTuple, Parameter
|
|
236
252
|
>>>
|
|
237
253
|
>>> class Net(nn.Cell):
|
|
238
254
|
... def __init__(self):
|
|
@@ -284,8 +300,7 @@ class ForwardValueAndGrad(Cell):
|
|
|
284
300
|
self.get_by_list = get_by_list
|
|
285
301
|
self.sens_param = sens_param
|
|
286
302
|
self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
|
|
287
|
-
|
|
288
|
-
self._jit_config_dict = network.jit_config_dict
|
|
303
|
+
self._get_attr_from_cell(network)
|
|
289
304
|
|
|
290
305
|
def construct(self, *inputs):
|
|
291
306
|
grad_inputs = inputs
|
|
@@ -310,7 +325,11 @@ class TrainOneStepCell(Cell):
|
|
|
310
325
|
Args:
|
|
311
326
|
network (Cell): The training network. The network only supports single output.
|
|
312
327
|
optimizer (Union[Cell]): Optimizer for updating the network parameters.
|
|
313
|
-
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is
|
|
328
|
+
sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is
|
|
329
|
+
``None`` , which is ``1.0`` .
|
|
330
|
+
return_grad (bool): Whether to return gradient. If ``True``, it will return the gradient in the form of a dict
|
|
331
|
+
while returning loss. The key of the dict is the parameter name corresponding to the gradient, and value
|
|
332
|
+
is the gradient value. Default value is ``False`` .
|
|
314
333
|
|
|
315
334
|
Inputs:
|
|
316
335
|
- **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
@@ -325,7 +344,10 @@ class TrainOneStepCell(Cell):
|
|
|
325
344
|
``Ascend`` ``GPU`` ``CPU``
|
|
326
345
|
|
|
327
346
|
Examples:
|
|
328
|
-
>>>
|
|
347
|
+
>>> import mindspore.nn as nn
|
|
348
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
349
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
350
|
+
>>> net = LeNet5()
|
|
329
351
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
330
352
|
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
331
353
|
>>> #1) Using the WithLossCell provided by MindSpore
|
|
@@ -333,7 +355,7 @@ class TrainOneStepCell(Cell):
|
|
|
333
355
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
334
356
|
>>>
|
|
335
357
|
>>> #2) Using user-defined WithLossCell
|
|
336
|
-
>>> class MyWithLossCell(Cell):
|
|
358
|
+
>>> class MyWithLossCell(nn.Cell):
|
|
337
359
|
... def __init__(self, backbone, loss_fn):
|
|
338
360
|
... super(MyWithLossCell, self).__init__(auto_prefix=False)
|
|
339
361
|
... self._backbone = backbone
|
|
@@ -351,16 +373,26 @@ class TrainOneStepCell(Cell):
|
|
|
351
373
|
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
|
352
374
|
"""
|
|
353
375
|
|
|
354
|
-
def __init__(self, network, optimizer, sens=
|
|
376
|
+
def __init__(self, network, optimizer, sens=None, return_grad=False):
|
|
355
377
|
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
|
356
378
|
self.network = network
|
|
357
379
|
self.network.set_grad()
|
|
358
380
|
self.optimizer = optimizer
|
|
359
381
|
self.weights = self.optimizer.parameters
|
|
360
382
|
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
|
383
|
+
self.grad_no_sens = C.GradOperation(get_by_list=True)
|
|
361
384
|
self.sens = sens
|
|
385
|
+
if self.sens == 0:
|
|
386
|
+
raise ValueError("The input argument of 'sens' can not be 0.")
|
|
387
|
+
self.sense_flag = True
|
|
388
|
+
if self.sens is None:
|
|
389
|
+
self.sense_flag = False
|
|
390
|
+
self.sens = 1.0
|
|
391
|
+
self.return_grad = return_grad
|
|
392
|
+
if return_grad:
|
|
393
|
+
self.weights_name = [i.name for i in self.optimizer.parameters]
|
|
362
394
|
self.reducer_flag = False
|
|
363
|
-
self.grad_reducer =
|
|
395
|
+
self.grad_reducer = nn.Identity()
|
|
364
396
|
self.parallel_mode = _get_parallel_mode()
|
|
365
397
|
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \
|
|
366
398
|
_is_pynative_parallel()
|
|
@@ -379,15 +411,34 @@ class TrainOneStepCell(Cell):
|
|
|
379
411
|
create_group(server_group_name, group_list[current_index])
|
|
380
412
|
group = server_group_name
|
|
381
413
|
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
|
|
382
|
-
|
|
383
|
-
self._jit_config_dict = network.jit_config_dict
|
|
414
|
+
self._get_attr_from_cell(network)
|
|
384
415
|
|
|
385
416
|
def construct(self, *inputs):
|
|
417
|
+
if not self.sense_flag:
|
|
418
|
+
return self._no_sens_impl(*inputs)
|
|
386
419
|
loss = self.network(*inputs)
|
|
387
420
|
sens = F.fill(loss.dtype, loss.shape, self.sens)
|
|
388
421
|
grads = self.grad(self.network, self.weights)(*inputs, sens)
|
|
389
422
|
grads = self.grad_reducer(grads)
|
|
390
423
|
loss = F.depend(loss, self.optimizer(grads))
|
|
424
|
+
if self.return_grad:
|
|
425
|
+
grad_with_param_name = {}
|
|
426
|
+
for index, value in enumerate(grads):
|
|
427
|
+
grad_with_param_name[self.weights_name[index]] = value
|
|
428
|
+
return loss, grad_with_param_name
|
|
429
|
+
return loss
|
|
430
|
+
|
|
431
|
+
def _no_sens_impl(self, *inputs):
|
|
432
|
+
"""construct implementation when the 'sens' parameter is passed in."""
|
|
433
|
+
loss = self.network(*inputs)
|
|
434
|
+
grads = self.grad_no_sens(self.network, self.weights)(*inputs)
|
|
435
|
+
grads = self.grad_reducer(grads)
|
|
436
|
+
loss = F.depend(loss, self.optimizer(grads))
|
|
437
|
+
if self.return_grad:
|
|
438
|
+
grad_with_param_name = {}
|
|
439
|
+
for index, value in enumerate(grads):
|
|
440
|
+
grad_with_param_name[self.weights_name[index]] = value
|
|
441
|
+
return loss, grad_with_param_name
|
|
391
442
|
return loss
|
|
392
443
|
|
|
393
444
|
|
|
@@ -412,7 +463,7 @@ class GetNextSingleOp(Cell):
|
|
|
412
463
|
>>> import mindspore
|
|
413
464
|
>>> from mindspore import ops, nn
|
|
414
465
|
>>> from mindspore import dataset as ds
|
|
415
|
-
>>> from mindspore
|
|
466
|
+
>>> from mindspore import dtype as mstype
|
|
416
467
|
>>>
|
|
417
468
|
>>> data_path = "/path/to/MNIST_Data/train/"
|
|
418
469
|
>>> train_dataset = ds.MnistDataset(data_path, num_samples=10)
|
|
@@ -459,8 +510,7 @@ class _VirtualDatasetCell(Cell):
|
|
|
459
510
|
super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
|
|
460
511
|
self._backbone = backbone
|
|
461
512
|
self._virtual_dataset = _VirtualDataset()
|
|
462
|
-
|
|
463
|
-
self._jit_config_dict = backbone.jit_config_dict
|
|
513
|
+
self._get_attr_from_cell(backbone)
|
|
464
514
|
|
|
465
515
|
def construct(self, *inputs):
|
|
466
516
|
output = self._virtual_dataset(*inputs)
|
|
@@ -469,6 +519,8 @@ class _VirtualDatasetCell(Cell):
|
|
|
469
519
|
|
|
470
520
|
@_primexpr
|
|
471
521
|
def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
|
|
522
|
+
if F.isconstant(input_shape[0]) is False:
|
|
523
|
+
return
|
|
472
524
|
if input_shape[0] % micro_size != 0:
|
|
473
525
|
raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
|
|
474
526
|
f"divided by micro size({micro_size})")
|
|
@@ -493,8 +545,8 @@ class _MicroBatch(Cell):
|
|
|
493
545
|
for each_input in inputs:
|
|
494
546
|
input_shape = self.shape(each_input)
|
|
495
547
|
_check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
|
|
496
|
-
micro_batch_begin =
|
|
497
|
-
micro_batch_end = (
|
|
548
|
+
micro_batch_begin = (input_shape[0] // self.micro_size) * i
|
|
549
|
+
micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1)
|
|
498
550
|
strided_slice_begin = (micro_batch_begin,)
|
|
499
551
|
strided_slice_strides = (1,)
|
|
500
552
|
for _ in range(len(input_shape) - 1):
|
|
@@ -520,7 +572,7 @@ class MicroBatchInterleaved(Cell):
|
|
|
520
572
|
|
|
521
573
|
Args:
|
|
522
574
|
network (Cell): The target network to wrap.
|
|
523
|
-
interleave_num (int, optional): split num of batch size. Default: 2.
|
|
575
|
+
interleave_num (int, optional): split num of batch size. Default: ``2`` .
|
|
524
576
|
|
|
525
577
|
Inputs:
|
|
526
578
|
tuple[Tensor]. It's the same with the input of the `network` .
|
|
@@ -532,8 +584,11 @@ class MicroBatchInterleaved(Cell):
|
|
|
532
584
|
``Ascend`` ``GPU``
|
|
533
585
|
|
|
534
586
|
Examples:
|
|
535
|
-
>>>
|
|
536
|
-
>>>
|
|
587
|
+
>>> import mindspore.nn as nn
|
|
588
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
589
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
590
|
+
>>> net = LeNet5()
|
|
591
|
+
>>> net = nn.MicroBatchInterleaved(net, 2)
|
|
537
592
|
"""
|
|
538
593
|
def __init__(self, network, interleave_num=2):
|
|
539
594
|
super(MicroBatchInterleaved, self).__init__(auto_prefix=False)
|
|
@@ -552,8 +607,7 @@ class MicroBatchInterleaved(Cell):
|
|
|
552
607
|
interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
|
|
553
608
|
interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
|
|
554
609
|
self.interleave_inputs.append(interleave_data)
|
|
555
|
-
|
|
556
|
-
self._jit_config_dict = network.jit_config_dict
|
|
610
|
+
self._get_attr_from_cell(network)
|
|
557
611
|
|
|
558
612
|
def construct(self, *inputs):
|
|
559
613
|
output = 0.0
|
|
@@ -578,8 +632,11 @@ class PipelineCell(Cell):
|
|
|
578
632
|
``Ascend`` ``GPU``
|
|
579
633
|
|
|
580
634
|
Examples:
|
|
581
|
-
>>>
|
|
582
|
-
>>>
|
|
635
|
+
>>> import mindspore.nn as nn
|
|
636
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
637
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
638
|
+
>>> net = LeNet5()
|
|
639
|
+
>>> net = nn.PipelineCell(net, 4)
|
|
583
640
|
"""
|
|
584
641
|
def __init__(self, network, micro_size):
|
|
585
642
|
super(PipelineCell, self).__init__(auto_prefix=False)
|
|
@@ -587,13 +644,64 @@ class PipelineCell(Cell):
|
|
|
587
644
|
self.micro_inputs = nn.CellList()
|
|
588
645
|
self.micro_size = micro_size
|
|
589
646
|
self.add_list = []
|
|
647
|
+
if not isinstance(micro_size, int):
|
|
648
|
+
raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
|
|
649
|
+
"but got the type : {}.".format(type(micro_size)))
|
|
650
|
+
if micro_size <= 0:
|
|
651
|
+
raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
|
|
652
|
+
"but got {}.".format(micro_size))
|
|
590
653
|
for i in range(micro_size):
|
|
591
654
|
micro_input = _MicroBatch(micro_size)
|
|
592
655
|
self.micro_inputs.append(micro_input)
|
|
593
656
|
self.add = P.Add().add_prim_attr("pipeline_end", i)
|
|
594
657
|
self.add_list.append(self.add)
|
|
595
|
-
|
|
596
|
-
|
|
658
|
+
self._get_attr_from_cell(network)
|
|
659
|
+
|
|
660
|
+
def construct(self, *inputs):
|
|
661
|
+
ret = None
|
|
662
|
+
for i in range(self.micro_size):
|
|
663
|
+
micro_input = self.micro_inputs[i](i, *inputs)
|
|
664
|
+
output = self.network(*micro_input)
|
|
665
|
+
if ret is not None:
|
|
666
|
+
ret = self.add_list[i](ret, output)
|
|
667
|
+
else:
|
|
668
|
+
ret = output
|
|
669
|
+
return ret
|
|
670
|
+
|
|
671
|
+
class GradAccumulationCell(Cell):
|
|
672
|
+
"""
|
|
673
|
+
Wrap the network with Micro Batch.
|
|
674
|
+
|
|
675
|
+
Args:
|
|
676
|
+
network (Cell): The target network to wrap.
|
|
677
|
+
micro_size (int): MicroBatch size.
|
|
678
|
+
|
|
679
|
+
Supported Platforms:
|
|
680
|
+
``Ascend`` ``GPU``
|
|
681
|
+
|
|
682
|
+
Examples:
|
|
683
|
+
>>> net = Net()
|
|
684
|
+
>>> net = GradAccumulationCell(net, 4)
|
|
685
|
+
"""
|
|
686
|
+
def __init__(self, network, micro_size):
|
|
687
|
+
super(GradAccumulationCell, self).__init__(auto_prefix=False)
|
|
688
|
+
self.network = network
|
|
689
|
+
self.micro_inputs = nn.CellList()
|
|
690
|
+
self.micro_size = micro_size
|
|
691
|
+
self.add_list = []
|
|
692
|
+
if not isinstance(micro_size, int):
|
|
693
|
+
raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
|
|
694
|
+
"but got the type : {}.".format(type(micro_size)))
|
|
695
|
+
if micro_size <= 0:
|
|
696
|
+
raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be large than 0, "
|
|
697
|
+
"but got {}.".format(micro_size))
|
|
698
|
+
for i in range(micro_size):
|
|
699
|
+
micro_input = _MicroBatch(micro_size)
|
|
700
|
+
micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
|
|
701
|
+
self.micro_inputs.append(micro_input)
|
|
702
|
+
self.add = P.Add().add_prim_attr("forward_end", i)
|
|
703
|
+
self.add_list.append(self.add)
|
|
704
|
+
self._get_attr_from_cell(network)
|
|
597
705
|
|
|
598
706
|
def construct(self, *inputs):
|
|
599
707
|
ret = None
|
|
@@ -613,23 +721,37 @@ def _pipeline_clear_grad(accu_grad, grad):
|
|
|
613
721
|
return F.assign(accu_grad, zeros)
|
|
614
722
|
|
|
615
723
|
|
|
616
|
-
class
|
|
724
|
+
class _TrainGradAccuStepCell(TrainOneStepCell):
|
|
617
725
|
"""
|
|
618
726
|
Wraps the network with an optimizer in pipeline mode.
|
|
619
727
|
"""
|
|
620
|
-
def __init__(self, network, optimizer, sens=
|
|
621
|
-
super(
|
|
728
|
+
def __init__(self, network, optimizer, sens=None):
|
|
729
|
+
super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens)
|
|
622
730
|
self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
|
|
623
731
|
self.hyper_map = ops.HyperMap()
|
|
624
732
|
self.opt_shard = _get_enable_parallel_optimizer()
|
|
625
|
-
|
|
626
|
-
self._jit_config_dict = network.jit_config_dict
|
|
733
|
+
self._get_attr_from_cell(network)
|
|
627
734
|
|
|
628
735
|
def construct(self, *inputs):
|
|
629
|
-
|
|
736
|
+
if not self.sense_flag:
|
|
737
|
+
return self._no_sens_impl(*inputs)
|
|
630
738
|
loss = self.network(*inputs)
|
|
631
|
-
sens = ops.
|
|
632
|
-
grads = self.grad(self.network, weights)(*inputs, sens)
|
|
739
|
+
sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
|
|
740
|
+
grads = self.grad(self.network, self.weights)(*inputs, sens)
|
|
741
|
+
accu_grads = ops.depend(self.accu_grads, grads)
|
|
742
|
+
if self.opt_shard:
|
|
743
|
+
succ = self.optimizer(grads)
|
|
744
|
+
else:
|
|
745
|
+
succ = self.optimizer(accu_grads)
|
|
746
|
+
loss = ops.depend(loss, succ)
|
|
747
|
+
clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
|
|
748
|
+
loss = ops.depend(loss, clear)
|
|
749
|
+
return loss
|
|
750
|
+
|
|
751
|
+
def _no_sens_impl(self, *inputs):
|
|
752
|
+
"""construct implementation when the 'sens' parameter is passed in."""
|
|
753
|
+
loss = self.network(*inputs)
|
|
754
|
+
grads = self.grad_no_sens(self.network, self.weights)(*inputs)
|
|
633
755
|
accu_grads = ops.depend(self.accu_grads, grads)
|
|
634
756
|
if self.opt_shard:
|
|
635
757
|
succ = self.optimizer(grads)
|
|
@@ -657,16 +779,18 @@ class VirtualDatasetCellTriple(Cell):
|
|
|
657
779
|
backbone (Cell): The target network to wrap.
|
|
658
780
|
|
|
659
781
|
Examples:
|
|
660
|
-
>>>
|
|
661
|
-
>>>
|
|
782
|
+
>>> import mindspore.nn as nn
|
|
783
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
784
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
785
|
+
>>> net = LeNet5()
|
|
786
|
+
>>> net = nn.VirtualDatasetCellTriple(net)
|
|
662
787
|
"""
|
|
663
788
|
|
|
664
789
|
def __init__(self, backbone):
|
|
665
790
|
super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
|
|
666
791
|
logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
|
|
667
792
|
self._backbone = backbone
|
|
668
|
-
|
|
669
|
-
self._jit_config_dict = backbone.jit_config_dict
|
|
793
|
+
self._get_attr_from_cell(backbone)
|
|
670
794
|
|
|
671
795
|
def construct(self, a, b, c):
|
|
672
796
|
return self._backbone(a, b, c)
|
|
@@ -681,7 +805,7 @@ class WithEvalCell(Cell):
|
|
|
681
805
|
Args:
|
|
682
806
|
network (Cell): The forward network.
|
|
683
807
|
loss_fn (Cell): The loss function.
|
|
684
|
-
add_cast_fp32 (bool): Whether to adjust the data type to float32. Default: False.
|
|
808
|
+
add_cast_fp32 (bool): Whether to adjust the data type to float32. Default: ``False`` .
|
|
685
809
|
|
|
686
810
|
Inputs:
|
|
687
811
|
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
|
@@ -698,8 +822,10 @@ class WithEvalCell(Cell):
|
|
|
698
822
|
``Ascend`` ``GPU`` ``CPU``
|
|
699
823
|
|
|
700
824
|
Examples:
|
|
701
|
-
>>>
|
|
702
|
-
>>>
|
|
825
|
+
>>> import mindspore.nn as nn
|
|
826
|
+
>>> # Define a forward network without loss function, taking LeNet5 as an example.
|
|
827
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
828
|
+
>>> net = LeNet5()
|
|
703
829
|
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
704
830
|
>>> eval_net = nn.WithEvalCell(net, loss_fn)
|
|
705
831
|
"""
|
|
@@ -709,8 +835,7 @@ class WithEvalCell(Cell):
|
|
|
709
835
|
self._network = network
|
|
710
836
|
self._loss_fn = loss_fn
|
|
711
837
|
self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
|
|
712
|
-
|
|
713
|
-
self._jit_config_dict = network.jit_config_dict
|
|
838
|
+
self._get_attr_from_cell(network)
|
|
714
839
|
|
|
715
840
|
def construct(self, data, label):
|
|
716
841
|
outputs = self._network(data)
|
|
@@ -297,11 +297,11 @@ class DistributedGradReducer(Cell):
|
|
|
297
297
|
parameters (list): the parameters to be updated.
|
|
298
298
|
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
|
|
299
299
|
When it is not specified, using the configuration `gradients_mean` in auto_parallel_context.
|
|
300
|
-
Default: None.
|
|
301
|
-
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
|
|
302
|
-
fusion_type (int): The type of all reduce fusion. Default: 1.
|
|
300
|
+
Default: ``None`` .
|
|
301
|
+
degree (int): The mean coefficient. Usually it equals to device number. Default: ``None`` .
|
|
302
|
+
fusion_type (int): The type of all reduce fusion. Default: ``1`` .
|
|
303
303
|
group (str): The communication group to work on. Normally, the group should be created by create_group,
|
|
304
|
-
otherwise, using the default group. Default: GlobalComm.WORLD_COMM_GROUP.
|
|
304
|
+
otherwise, using the default group. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
|
|
305
305
|
|
|
306
306
|
Raises:
|
|
307
307
|
ValueError: If degree is not an int or less than 0.
|
|
@@ -314,21 +314,22 @@ class DistributedGradReducer(Cell):
|
|
|
314
314
|
Before running the following examples, you need to configure the communication environment variables.
|
|
315
315
|
|
|
316
316
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
317
|
-
Please see the `
|
|
318
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
317
|
+
Please see the `rank table Startup
|
|
318
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
319
319
|
for more details.
|
|
320
320
|
|
|
321
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
322
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
321
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
322
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
323
|
+
|
|
324
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
325
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
323
326
|
|
|
324
327
|
This example should be run with multiple devices.
|
|
325
328
|
|
|
326
329
|
>>> import numpy as np
|
|
327
330
|
>>> import mindspore as ms
|
|
328
331
|
>>> from mindspore.communication import init
|
|
329
|
-
>>> from mindspore import ops
|
|
330
|
-
>>> from mindspore import Parameter, Tensor
|
|
331
|
-
>>> from mindspore import nn
|
|
332
|
+
>>> from mindspore import Parameter, Tensor, ops, nn
|
|
332
333
|
>>>
|
|
333
334
|
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
334
335
|
>>> init()
|
|
@@ -358,7 +359,7 @@ class DistributedGradReducer(Cell):
|
|
|
358
359
|
... def construct(self, *args):
|
|
359
360
|
... weights = self.weights
|
|
360
361
|
... loss = self.network(*args)
|
|
361
|
-
... sens =
|
|
362
|
+
... sens = F.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
|
|
362
363
|
... grads = self.grad(self.network, weights)(*args, sens)
|
|
363
364
|
... if self.reducer_flag:
|
|
364
365
|
... # apply grad reducer on grads
|
|
@@ -391,6 +392,7 @@ class DistributedGradReducer(Cell):
|
|
|
391
392
|
|
|
392
393
|
def __init__(self, parameters, mean=None, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP):
|
|
393
394
|
super(DistributedGradReducer, self).__init__(auto_prefix=False)
|
|
395
|
+
self._check_parallel_mode()
|
|
394
396
|
self.map_ = C.Map()
|
|
395
397
|
self.mean = mean
|
|
396
398
|
if mean is None:
|
|
@@ -457,3 +459,10 @@ class DistributedGradReducer(Cell):
|
|
|
457
459
|
self.allreduce), self.allreduce_filter, grads)
|
|
458
460
|
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
|
|
459
461
|
return new_grad
|
|
462
|
+
|
|
463
|
+
def _check_parallel_mode(self):
|
|
464
|
+
"""check parallel mode"""
|
|
465
|
+
parallel_mode = context.get_auto_parallel_context('parallel_mode')
|
|
466
|
+
if context.get_context('mode') == context.GRAPH_MODE and parallel_mode in (
|
|
467
|
+
context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL):
|
|
468
|
+
raise RuntimeError("{} can not use DistributedGradReducer in graph mode".format(parallel_mode))
|