mindspore-2.0.0rc1-cp38-none-any.whl → mindspore-2.2.0-cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -23,8 +23,8 @@ from collections import OrderedDict
|
|
|
23
23
|
from types import FunctionType, MethodType
|
|
24
24
|
import numpy
|
|
25
25
|
|
|
26
|
-
import mindspore.dataset as ds
|
|
27
26
|
from mindspore._checkparam import args_type_check
|
|
27
|
+
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
28
28
|
from mindspore import log as logger
|
|
29
29
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
|
30
30
|
from mindspore.common.hook_handle import HookHandle
|
|
@@ -42,29 +42,16 @@ from mindspore.ops.primitive import Primitive
|
|
|
42
42
|
from mindspore.ops.operations import _inner_ops as inner
|
|
43
43
|
from mindspore.parallel.shard import Shard
|
|
44
44
|
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
45
|
+
from mindspore.common._decorator import deprecated
|
|
46
|
+
from mindspore._c_expression import PackExpander
|
|
47
|
+
from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
|
|
45
48
|
|
|
46
49
|
|
|
47
50
|
def _check_args(args):
|
|
48
51
|
"""Check the input args's type"""
|
|
49
|
-
index = 1
|
|
50
52
|
for item in args:
|
|
51
53
|
if isinstance(item, Tensor) and item.has_init:
|
|
52
54
|
item.init_data()
|
|
53
|
-
elif isinstance(item, numpy.ndarray):
|
|
54
|
-
suffix = "th"
|
|
55
|
-
if index == 1:
|
|
56
|
-
suffix = "st"
|
|
57
|
-
elif index == 2:
|
|
58
|
-
suffix = "nd"
|
|
59
|
-
elif index == 3:
|
|
60
|
-
suffix = "rd"
|
|
61
|
-
|
|
62
|
-
input_index = str(index) + suffix
|
|
63
|
-
raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
|
|
64
|
-
f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
|
|
65
|
-
f"), and tuple or list containing only these types, and dict whose values are these "
|
|
66
|
-
f"types, but the {input_index} arg type is {type(item)}.")
|
|
67
|
-
index += 1
|
|
68
55
|
|
|
69
56
|
|
|
70
57
|
class Cell(Cell_):
|
|
@@ -77,15 +64,25 @@ class Cell(Cell_):
|
|
|
77
64
|
graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
|
|
78
65
|
PYNATIVE_MODE (dynamic graph mode).
|
|
79
66
|
|
|
67
|
+
.. note::
|
|
68
|
+
Cell is the inference mode by default. For a class that inherits a Cell,
|
|
69
|
+
if the training and inference have different structures, the subclass performs the inference branch by default.
|
|
70
|
+
To set the training mode, refer to `mindspore.nn.Cell.set_train` .
|
|
71
|
+
|
|
72
|
+
.. warning::
|
|
73
|
+
In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
|
|
74
|
+
named 'phase' or 'cells', otherwise, an error will be raised.
|
|
75
|
+
|
|
80
76
|
Args:
|
|
81
77
|
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
|
|
82
|
-
affects the names of parameters in the `Cell`. If set to True, the parameter name will be
|
|
83
|
-
automatically prefixed, otherwise not. In general, the backbone network should be set to
|
|
84
|
-
otherwise the duplicate name problem will appear. The cell to train the backbone
|
|
85
|
-
optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to
|
|
86
|
-
parameter name in backbone will be changed by mistake.
|
|
78
|
+
affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
|
|
79
|
+
automatically prefixed, otherwise not. In general, the backbone network should be set to
|
|
80
|
+
``True`` , otherwise the duplicate name problem will appear. The cell to train the backbone
|
|
81
|
+
network, such as optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to
|
|
82
|
+
``False`` , otherwise the parameter name in backbone will be changed by mistake.
|
|
83
|
+
Default: ``True`` .
|
|
87
84
|
flags (dict, optional): Network configuration information, currently it is used for the binding of network
|
|
88
|
-
and dataset. Users can also customize network attributes by this parameter. Default: None.
|
|
85
|
+
and dataset. Users can also customize network attributes by this parameter. Default: ``None`` .
|
|
89
86
|
|
|
90
87
|
Supported Platforms:
|
|
91
88
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -167,7 +164,9 @@ class Cell(Cell_):
|
|
|
167
164
|
self.saved_dynamic_shape = None
|
|
168
165
|
self._jit_config_dict = dict()
|
|
169
166
|
self.grad_ops_label = False
|
|
170
|
-
self.
|
|
167
|
+
self.ge_sync_data = False
|
|
168
|
+
self._is_check_and_refresh = False
|
|
169
|
+
self._amp_level = ""
|
|
171
170
|
|
|
172
171
|
def __getstate__(self):
|
|
173
172
|
base = Cell_.__getstate__(self)
|
|
@@ -199,6 +198,23 @@ class Cell(Cell_):
|
|
|
199
198
|
def param_prefix(self):
|
|
200
199
|
"""
|
|
201
200
|
Param prefix is the prefix of current cell's direct child parameter.
|
|
201
|
+
|
|
202
|
+
Examples:
|
|
203
|
+
>>> import mindspore as ms
|
|
204
|
+
>>> from mindspore import Tensor, nn
|
|
205
|
+
...
|
|
206
|
+
>>> class Net(nn.Cell):
|
|
207
|
+
... def __init__(self):
|
|
208
|
+
... super(Net, self).__init__()
|
|
209
|
+
... self.dense = nn.Dense(2, 2)
|
|
210
|
+
...
|
|
211
|
+
... def construct(self, x):
|
|
212
|
+
... x = self.dense(x)
|
|
213
|
+
... return x
|
|
214
|
+
>>> net = Net()
|
|
215
|
+
>>> net.update_cell_prefix()
|
|
216
|
+
>>> print(net.dense.param_prefix)
|
|
217
|
+
dense
|
|
202
218
|
"""
|
|
203
219
|
return self._param_prefix
|
|
204
220
|
|
|
@@ -206,6 +222,10 @@ class Cell(Cell_):
|
|
|
206
222
|
def bprop_debug(self):
|
|
207
223
|
"""
|
|
208
224
|
Get whether cell custom bprop debug is enabled.
|
|
225
|
+
|
|
226
|
+
Tutorial Examples:
|
|
227
|
+
- `Cell and Parameter - Custom Cell Reverse
|
|
228
|
+
<https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
|
|
209
229
|
"""
|
|
210
230
|
return self._bprop_debug
|
|
211
231
|
|
|
@@ -220,7 +240,7 @@ class Cell(Cell_):
|
|
|
220
240
|
and add to graph when bprop debug is false.
|
|
221
241
|
|
|
222
242
|
Args:
|
|
223
|
-
value (bool): Specifies whether to enable bprop debug. Default: False
|
|
243
|
+
value (bool): Specifies whether to enable bprop debug. Default: ``False``.
|
|
224
244
|
"""
|
|
225
245
|
if not isinstance(value, bool):
|
|
226
246
|
raise TypeError(f"For 'Cell', the property 'bprop_debug' must be bool type, but got type {type(value)}.")
|
|
@@ -312,6 +332,21 @@ class Cell(Cell_):
|
|
|
312
332
|
for item in self.trainable_params():
|
|
313
333
|
item.add_pipeline_stage(value)
|
|
314
334
|
|
|
335
|
+
@property
|
|
336
|
+
def pipeline_segment(self):
|
|
337
|
+
return self._pipeline_segment
|
|
338
|
+
|
|
339
|
+
@pipeline_segment.setter
|
|
340
|
+
def pipeline_segment(self, value):
|
|
341
|
+
if not isinstance(value, int) or isinstance(value, bool):
|
|
342
|
+
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
343
|
+
"must be int type, but got type : {}".format(type(value)))
|
|
344
|
+
|
|
345
|
+
if value < 0:
|
|
346
|
+
raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
347
|
+
"can not be less than 0, but got {}".format(value))
|
|
348
|
+
self._pipeline_segment = value
|
|
349
|
+
|
|
315
350
|
@property
|
|
316
351
|
def parallel_parameter_merge_net_dict(self):
|
|
317
352
|
return self._parallel_parameter_merge_net_dict
|
|
@@ -348,13 +383,14 @@ class Cell(Cell_):
|
|
|
348
383
|
if '_params_list' in self.__dict__:
|
|
349
384
|
params_list = self.__dict__['_params_list']
|
|
350
385
|
if name in params_list:
|
|
351
|
-
return
|
|
386
|
+
return params_list[name]
|
|
352
387
|
raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
|
353
388
|
|
|
354
389
|
def __del__(self):
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
390
|
+
if isinstance(cells_compile_cache, dict):
|
|
391
|
+
# while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
|
|
392
|
+
# here using pop(id(self), None) to avoid KeyError exception
|
|
393
|
+
cells_compile_cache.pop(id(self), None)
|
|
358
394
|
try:
|
|
359
395
|
if self.compile_cache:
|
|
360
396
|
_cell_graph_executor.del_net_res(self, self.compile_cache)
|
|
@@ -367,11 +403,11 @@ class Cell(Cell_):
|
|
|
367
403
|
del self._params[name]
|
|
368
404
|
elif name in self._cells:
|
|
369
405
|
del self._cells[name]
|
|
406
|
+
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
407
|
+
del self._params_list[name]
|
|
408
|
+
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
409
|
+
del self._tensor_list[name]
|
|
370
410
|
else:
|
|
371
|
-
if '_params_list' in self.__dict__ and name in self._params_list:
|
|
372
|
-
del self._params_list[name]
|
|
373
|
-
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
374
|
-
del self._tensor_list[name]
|
|
375
411
|
object.__delattr__(self, name)
|
|
376
412
|
self._attr_synced = False
|
|
377
413
|
|
|
@@ -383,7 +419,8 @@ class Cell(Cell_):
                 res.append(self._cast_mixed_precision_inputs(item, dst_type))
             elif isinstance(item, float):
                 res.append(self.cast(item, dst_type))
-            elif hasattr(item, "dtype") and item.dtype in
+            elif hasattr(item, "dtype") and item.dtype in \
+                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
                 res.append(self.cast(item, dst_type))
             else:
                 res.append(item)
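The widened guard above now treats `bfloat16` as a castable float type. A sketch of the membership test on its own, assuming `mindspore.dtype` exposes these members as shown in the hunk:

    from mindspore import dtype as mstype

    FLOAT_TYPES = {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16}
    print(mstype.bfloat16 in FLOAT_TYPES)  # True: bfloat16 inputs are cast when needed
    print(mstype.int32 in FLOAT_TYPES)     # False: integer inputs pass through unchanged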
@@ -438,7 +475,7 @@ class Cell(Cell_):
         if self._enable_forward_pre_hook:
             cast_inputs = self._run_forward_pre_hook(cast_inputs)
         if self._enable_backward_hook:
-            output = self._backward_hook_construct(*cast_inputs)
+            output = self._backward_hook_construct(*cast_inputs, **kwargs)
         elif hasattr(self, "_shard_fn"):
             output = self._shard_fn(*cast_inputs, **kwargs)
         else:
@@ -546,19 +583,19 @@ class Cell(Cell_):
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
                 defines the layout of the corresponding input and None represents a data parallel strategy.
             out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
-                It is not in use right now. Default: None.
+                It is not in use right now. Default: ``None`` .
             parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                 defines the layout of the parameter like "param_name: layout".
                 The key is a parameter name of type 'str'.
                 The value is a 1-D integer tuple, indicating the corresponding layout.
                 If the parameter name is incorrect or the corresponding parameter
                 has been set, the parameter setting will be ignored.
-                Default: None.
+                Default: ``None`` .
             device (string): Select a certain device target. It is not in use right now.
-                Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
+                Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
             level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
                 over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-                use right now. Support ["0", "1", "2"]. Default:
+                use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .

         Returns:
             Cell, the cell itself.
@@ -627,6 +664,13 @@ class Cell(Cell_):
             args = bound_arguments.args
             kwargs = bound_arguments.kwargs

+        if PackFunc.is_tracing():
+            return self._run_tracefunc(*args, **kwargs)
+
+        if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+            self.check_names_and_refresh_name()
+            self._is_check_and_refresh = True
+
         # Run in Graph mode.
         if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
             self._check_construct_args(*args)
@@ -646,7 +690,7 @@ class Cell(Cell_):
         _check_args(args)
         self._check_cell_flags_in_pynative()

-        if self.requires_grad:
+        if self.requires_grad and _pynative_executor.enable_grad():
             _pynative_executor.set_grad_flag(True)

         if self._dynamic_shape_inputs is not None:
@@ -881,16 +925,16 @@ class Cell(Cell_):
         Examples:
             >>> import numpy as np
             >>> import mindspore as ms
-            >>> from mindspore import nn, Tensor
+            >>> from mindspore import nn, Tensor
             >>>
-            >>> class
+            >>> class ReluNet(nn.Cell):
             ...     def __init__(self):
-            ...         super(
+            ...         super(ReluNet, self).__init__()
             ...         self.relu = nn.ReLU()
             ...     def construct(self, x):
             ...         return self.relu(x)
             >>>
-            >>> net =
+            >>> net = ReluNet()
             >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
             >>> net.set_inputs(input_dyn)
             >>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
@@ -899,15 +943,10 @@ class Cell(Cell_):
         if self.grad_ops_label:
             logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
                            f'generated.')
-        for ele in inputs:
-            if isinstance(ele, str):
-                raise TypeError(f"For element in 'set_inputs', the type must not be str.")
         self._dynamic_shape_inputs = inputs
         self._check_construct_args(*inputs)
-        if self._dynamic_shape_inputs:
-            ds.config.set_dynamic_shape(True)
         if context._get_mode() == context.PYNATIVE_MODE:
-            _pynative_executor.set_dynamic_input(self)
+            _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)

     def get_inputs(self):
         """
@@ -918,6 +957,26 @@ class Cell(Cell_):

         .. warning::
             This is an experimental API that is subject to change or deletion.
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> from mindspore import nn, Tensor
+            >>>
+            >>> class ReluNet(nn.Cell):
+            ...     def __init__(self):
+            ...         super(ReluNet, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...     def construct(self, x):
+            ...         return self.relu(x)
+            >>>
+            >>> net = ReluNet()
+            >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
+            >>> net.set_inputs(input_dyn)
+            >>> get_inputs = net.get_inputs()
+            >>> print(get_inputs)
+            (Tensor(shape=[3, -1], dtype=Float32, value= ),)
+
         """

         return self._dynamic_shape_inputs
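As the new `get_inputs` example above shows, a `None` axis passed to `set_inputs` is stored as `-1`. A minimal sketch of that convention (the printed shape is taken from the docstring output in the hunk):

    import mindspore as ms
    from mindspore import Tensor

    dyn = Tensor(shape=[3, None], dtype=ms.float32)  # None marks a dynamic axis
    print(dyn.shape)                                 # (3, -1), matching the example above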
@@ -930,6 +989,10 @@ class Cell(Cell_):
             args (tuple): Args of the Cell object.
             kwargs (dict): Kwargs of the Cell object.
         """
+        # this is used only for test
+        if is_auto_dynamic() and (self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None):
+            self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
+
         if self._dynamic_shape_inputs is None:
             _cell_graph_executor.compile(self, phase=self.phase,
                                          jit_config_dict=self._jit_config_dict, *args, **kwargs)
@@ -955,7 +1018,7 @@ class Cell(Cell_):
             Object, the result of executing.
         """
         self.compile(*args, **kwargs)
-
+        self.add_flags(ge_sync_data=False)
         new_args = _get_args_for_run(self, args, kwargs)
         return _cell_graph_executor(self, *new_args, phase=self.phase)

@@ -969,7 +1032,8 @@ class Cell(Cell_):
         logger.warning("'auto_parallel_compile_and_run' function is deprecated.")

     def exec_checkpoint_graph(self):
-        """Executes saving checkpoint graph operation."""
+        """Executes GE saving checkpoint graph operation."""
+        self.add_flags(ge_sync_data=True)
         _cell_graph_executor(self, phase='save')

     def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
@@ -982,11 +1046,28 @@ class Cell(Cell_):
         Args:
             param_name (str): Name of the parameter.
             param (Parameter): Parameter to be inserted to the cell.
-            check_name_contain_dot (bool): Determines whether the name input is compatible. Default: True.
+            check_name_contain_dot (bool): Determines whether the name input is compatible. Default: ``True`` .

         Raises:
             KeyError: If the name of parameter is null or contains dot.
             TypeError: If the type of parameter is not Parameter.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn, Parameter
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...
+            ...     def construct(self, x):
+            ...         x = self.relu(x)
+            ...         return x
+            >>> net = Net()
+            >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
+            >>> print(net.bias)
+            Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
         """
         if not param_name:
             raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
@@ -1000,6 +1081,9 @@ class Cell(Cell_):
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
+        if param is None:
+            raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
+                            f"but got None.")
         if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
             param.name = param_name
         self._params[param_name] = param
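With the extra check above, passing `param=None` now fails fast instead of being stored silently. A short sketch of the new behavior:

    from mindspore import nn

    net = nn.ReLU()
    try:
        net.insert_param_to_cell("bias", None)  # rejected as of 2.2.0 per the hunk above
    except TypeError as e:
        print(e)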
@@ -1041,6 +1125,18 @@ class Cell(Cell_):
             KeyError: Child Cell's name is incorrect or duplicated with the other child name.
             TypeError: If type of `child_name` is not str.
             TypeError: Child Cell's type is incorrect.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> net1 = nn.ReLU()
+            >>> net2 = nn.Dense(2, 2)
+            >>> net1.insert_child_to_cell("child", net2)
+            >>> print(net1)
+            ReLU<
+              (child): Dense<input_channels=2, output_channels=2, has_bias=True>
+              >
         """
         if not isinstance(child_name, str):
             raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1107,10 +1203,29 @@ class Cell(Cell_):
         `init_parameters_data`, do not save these results.

         Args:
-            auto_parallel_mode (bool): If running in auto_parallel_mode. Default: False.
+            auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .

         Returns:
             Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.dense = nn.Dense(2, 2)
+            ...
+            ...     def construct(self, x):
+            ...         x = self.dense(x)
+            ...         return x
+            >>> net = Net()
+            >>> print(net.init_parameters_data())
+            {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
+            Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
+            Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
+            Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
         """
         replace = dict()

@@ -1152,10 +1267,28 @@ class Cell(Cell_):
         Gets the parameters dictionary of this cell.

         Args:
-            recurse (bool): Whether contains the parameters of subcells. Default: True.
+            recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .

         Returns:
             OrderedDict, return parameters dictionary.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn, Parameter
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.dense = nn.Dense(2, 2)
+            ...
+            ...     def construct(self, x):
+            ...         x = self.dense(x)
+            ...         return x
+            >>> net = Net()
+            >>> print(net.parameters_dict())
+            OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
+            requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
+            requires_grad=True))])
         """
         param_dict = OrderedDict()
         for param in self.get_parameters(expand=recurse):
@@ -1167,7 +1300,7 @@ class Cell(Cell_):
         Gets the parameters broadcast dictionary of this cell.

         Args:
-            recurse (bool): Whether contains the parameters of subcells. Default: True.
+            recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .

         Returns:
             OrderedDict, return parameters broadcast dictionary.
@@ -1185,11 +1318,11 @@ class Cell(Cell_):
         Adds the `prefix` string to the names of parameters.

         Args:
-            prefix (str): The prefix string. Default: ''.
-            recurse (bool): Whether contains the parameters of subcells. Default: True.
+            prefix (str): The prefix string. Default: ``''`` .
+            recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .
         """

-        Validator.
+        Validator.check_str_and_none_by_regular(prefix)
         for name, param in self.parameters_and_names(expand=recurse):
             if prefix != '':
                 param.is_init = False
@@ -1205,7 +1338,7 @@ class Cell(Cell_):

         Args:
             prefix (str): The prefix string. Default: ''.
-            recurse (bool): Whether contains the parameters of subcells. Default: True
+            recurse (bool): Whether contains the parameters of subcells. Default: ``True``.
         """

         Validator.check_str_by_regular(prefix)
@@ -1224,10 +1357,14 @@ class Cell(Cell_):
         Returns a list of all trainable parameters.

         Args:
-            recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
+            recurse (bool): Whether contains the trainable parameters of subcells. Default: ``True`` .

         Returns:
             List, the list of trainable parameters.
+
+        Tutorial Examples:
+            - `Model Training - Optimizer
+              <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
         """
         return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

@@ -1239,7 +1376,7 @@ class Cell(Cell_):
         Returns a list of all untrainable parameters.

         Args:
-            recurse (bool): Whether contains the untrainable parameters of subcells. Default: True.
+            recurse (bool): Whether contains the untrainable parameters of subcells. Default: ``True`` .

         Returns:
             List, the list of untrainable parameters.
@@ -1251,25 +1388,58 @@ class Cell(Cell_):
         """
         Returns an iterator over cell parameters.

-        Yields parameters of this cell. If `expand` is true, yield parameters of this cell and all subcells.
+        Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
+        For more details about subcells, please see the example below.

         Args:
-            expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield
-                that are direct members of this cell. Default: True.
+            expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
+                parameters that are direct members of this cell. Default: ``True`` .

         Returns:
             Iteration, all parameters at the cell.

         Examples:
-            >>>
-            >>>
-            >>>
-            >>>
-            ...
+            >>> import mindspore as ms
+            >>> from mindspore import nn, ops, Tensor
+            >>> import numpy as np
+            >>> class TestNet(nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+            ...         self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
+            ...     def construct(self, x):
+            ...         x += self.my_w1
+            ...         x = ops.reshape(x, (16,)) - self.my_w2
+            ...         return x
+            >>> class TestNet2(nn.Cell):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+            ...         # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
+            ...         # also be gathered.
+            ...         self.subcell = TestNet()
+            ...     def construct(self, x):
+            ...         x += self.my_w1
+            ...         x = ops.reshape(x, (16,)) - self.my_w2
+            ...         return x
+            >>> net = TestNet2()
+            >>> print([p for p in net.get_parameters(expand=True)])
+            [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
+            shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
+            requires_grad=True)]
         """
         for _, param in self.parameters_and_names(expand=expand):
             yield param

+    # pylint: disable=missing-docstring
+    def check_names_and_refresh_name(self):
+        if not hasattr(self, "_params"):
+            return
+        all_name = [i.name for i in dict(self.parameters_and_names()).values()]
+        if len(set(all_name)) < len(all_name):
+            self.update_parameters_name()
+            self.check_names()
+
     def check_names(self):
         """
         Check the names of cell parameters.
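`check_names_and_refresh_name`, added above, refreshes parameter names only when duplicates exist. The detection idiom in isolation, with hypothetical names:

    names = ["dense.weight", "dense.weight", "dense.bias"]  # hypothetical parameter names
    if len(set(names)) < len(names):
        print("duplicate parameter names; update_parameters_name() would run")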
@@ -1288,9 +1458,9 @@ class Cell(Cell_):
         Includes the parameter's name and itself.

         Args:
-            name_prefix (str): Namespace. Default: ''.
+            name_prefix (str): Namespace. Default: ``''`` .
             expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
-                that are direct members of this cell. Default: True.
+                that are direct members of this cell. Default: ``True`` .

         Returns:
             Iteration, all the names and corresponding parameters in the cell.
@@ -1302,6 +1472,10 @@ class Cell(Cell_):
             >>> for m in n.parameters_and_names():
             ...     if m[0]:
             ...         names.append(m[0])
+
+        Tutorial Examples:
+            - `Building a Network - Model Parameters
+              <https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
         """
         cells = []
         if expand:
@@ -1313,7 +1487,7 @@ class Cell(Cell_):
         for cell_name, cell in cells:
             params = cell._params.items()
             for par_name, par in params:
-                if par.inited_param is not None:
+                if par is not None and par.inited_param is not None:
                     par = par.inited_param
                 if par is not None and id(par) not in params_set:
                     params_set.add(id(par))
@@ -1328,8 +1502,8 @@ class Cell(Cell_):
         Returns an iterator over all cells in the network, including the cell's name and itself.

         Args:
-            cells (str): Cells to iterate over. Default: None.
-            name_prefix (str): Namespace. Default: ''.
+            cells (str): Cells to iterate over. Default: ``None`` .
+            name_prefix (str): Namespace. Default: ``''`` .

         Returns:
             Iteration, all the child cells and corresponding names in the cell.
@@ -1370,6 +1544,22 @@ class Cell(Cell_):

         Returns:
             Iteration, the immediate cells in the cell.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.dense = nn.Dense(2, 2)
+            ...
+            ...     def construct(self, x):
+            ...         x = self.dense(x)
+            ...         return x
+            >>> net = Net()
+            >>> print(net.cells())
+            odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
         """
         return self.name_cells().values()

@@ -1415,6 +1605,22 @@ class Cell(Cell_):

         Returns:
             Dict, all the child cells and corresponding names in the cell.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.dense = nn.Dense(2, 2)
+            ...
+            ...     def construct(self, x):
+            ...         x = self.dense(x)
+            ...         return x
+            >>> net = Net()
+            >>> print(net.name_cells())
+            OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
         """
         value_set = set()
         cells = OrderedDict()
@@ -1430,13 +1636,8 @@ class Cell(Cell_):
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
         if "fp32" in flags and flags.get("fp32", False):
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
-
-
-        """Add mixed precision flag to each cell"""
-        if "fp16" in flags and flags.get("fp16", False):
-            self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
-        if "fp32" in flags and flags.get("fp32", False):
-            self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+        if "bf16" in flags and flags.get("bf16", False):
+            Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

     def apply(self, fn):
         """
@@ -1478,7 +1679,24 @@ class Cell(Cell_):

         Args:
             flags (dict): Network configuration information, currently it is used for the binding of network and
-                dataset. Users can also customize network attributes by this parameter.
+                dataset. Users can also customize network attributes by this parameter.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...
+            ...     def construct(self, x):
+            ...         x = self.relu(x)
+            ...         return x
+            >>> net = Net()
+            >>> net.add_flags(sink_mode=True)
+            >>> print(net.sink_mode)
+            True
         """
         if not hasattr(self, "_func_graph_flags"):
             self._func_graph_flags = {}
@@ -1493,10 +1711,26 @@ class Cell(Cell_):

         Args:
             flags (dict): Network configuration information, currently it is used for the binding of network and
-                dataset. Users can also customize network attributes by this parameter.
+                dataset. Users can also customize network attributes by this parameter.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...
+            ...     def construct(self, x):
+            ...         x = self.relu(x)
+            ...         return x
+            >>> net = Net()
+            >>> net.add_flags_recursive(sink_mode=True)
+            >>> print(net.sink_mode)
+            True
         """
         self.add_flags(**flags)
-        self._add_mixed_precision_flag_recursive(**flags)
         for cell in self.cells():
             cell.add_flags_recursive(**flags)
         return self
@@ -1508,17 +1742,28 @@ class Cell(Cell_):
     def get_flags(self):
         """
         Get the self_defined attributes of the cell, which can be added by `add_flags` method.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...
+            ...     def construct(self, x):
+            ...         x = self.relu(x)
+            ...         return x
+            >>> net = Net()
+            >>> net.add_flags(sink_mode=True)
+            >>> print(net.get_flags())
+            {'sink_mode':True}
         """
         if not hasattr(self, "_func_graph_flags"):
             self._func_graph_flags = {}
         return self._func_graph_flags

-    def _set_mixed_precision_type_recursive(self, mixed_type):
-        """Set mixed precision type to each cell"""
-        Cell_.set_mixed_precision_type(self, mixed_type)
-        for cell in self.cells():
-            cell._set_mixed_precision_type_recursive(mixed_type)
-
     def to_float(self, dst_type):
         """
         Add cast on all inputs of cell and child cells to run with certain float type.
@@ -1531,13 +1776,13 @@ class Cell(Cell_):

         Args:
             dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
-                dst_type can be `mstype.float16` or `mstype.
+                dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.

         Returns:
             Cell, the cell itself.

         Raises:
-            ValueError: If dst_type is not mstype.float32
+            ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.

         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``
@@ -1549,19 +1794,15 @@ class Cell(Cell_):
         >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
         >>> net.to_float(mstype.float16)
         Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
-        padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=
-        """
-        if dst_type not in (mstype.float16, mstype.float32):
-            raise ValueError("For 'to_float', the argument 'dst_type' must be float32 or
-                             "but got {}.".format(dst_type))
-
-
-            self.to_float_fp16 = True
-        else:
-            self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
-            self.to_float_fp16 = False
-        flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
+        padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
+        """
+        if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
+            raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
+                             "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
+        flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
+                 'bf16': dst_type == mstype.bfloat16}
         self._add_init_args(**flags)
+        self.add_flags_recursive(**flags)
         return self

     def set_boost(self, boost_type):
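Per the hunk above, `to_float` now routes through flags (`fp16`/`fp32`/`bf16`) and `add_flags_recursive` instead of the removed recursive setter, and `bfloat16` becomes a legal target. A usage sketch, assuming a backend with bfloat16 support:

    from mindspore import nn
    from mindspore import dtype as mstype

    net = nn.Dense(2, 2)
    net.to_float(mstype.bfloat16)   # accepted in 2.2.0 per the hunk above
    try:
        net.to_float(mstype.int32)  # still rejected, now with the clearer message
    except ValueError as e:
        print(e)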
@@ -1570,7 +1811,7 @@ class Cell(Cell_):
         accelerate the algorithm in the algorithm library.

         If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
-        `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.
+        `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.

         Note:
             Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1594,12 +1835,12 @@ class Cell(Cell_):
     def set_grad(self, requires_grad=True):
         """
         Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
-        gradients. If true, the backward network needed to compute the gradients will be generated when the forward
+        gradients. If ``true`` , the backward network needed to compute the gradients will be generated when the forward
         network is executed.

         Args:
             requires_grad (bool): Specifies if the net need to grad, if it is
-                true, the cell will construct backward network in pynative mode. Default: True.
+                ``true`` , the cell will construct backward network in pynative mode. Default: ``True`` .

         Returns:
             Cell, the cell itself.
@@ -1620,15 +1861,19 @@ class Cell(Cell_):
         When execute function Model.eval(), framework will call Cell.set_train(False).

         Args:
-            mode (bool): Specifies whether the model is training. Default: True.
+            mode (bool): Specifies whether the model is training. Default: ``True`` .

         Returns:
             Cell, the cell itself.
+
+        Tutorial Examples:
+            - `Model Training - Implementing Training and Evaluation
+              <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
         """
-        if mode
-            self._phase = 'predict'
-        else:
+        if mode:
             self._phase = 'train'
+        else:
+            self._phase = 'predict'
         self.add_flags_recursive(training=mode)
         return self

@@ -1637,7 +1882,7 @@ class Cell(Cell_):
         Set parameter broadcast mode for this cell.

         Args:
-            mode (bool): Specifies whether the mode is parameter broadcast. Default: True.
+            mode (bool): Specifies whether the mode is parameter broadcast. Default: ``True`` .
         """
         self.add_flags_recursive(broadcast_flag=mode)
         return self
@@ -1657,16 +1902,27 @@ class Cell(Cell_):

         Args:
             jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...
+            ...     def construct(self, x):
+            ...         x = self.relu(x)
+            ...         return x
+            >>> net = Net()
+            >>> jitconfig = ms.JitConfig()
+            >>> net.set_jit_config(jitconfig)
         """
         if self._jit_config_dict:
             logger.warning("For Cell, jit config can only be set once, ignore this setting.")
         else:
             self._jit_config_dict = jit_config.jit_config_dict
-            enable_ge = os.getenv("MS_ENABLE_GE") == '1'
-            enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
-            if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
-                raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
-                                   format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))

     def flatten_weights(self, fusion_size=0):
         """
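The hunk above drops the runtime coupling between the `MS_ENABLE_GE` environment variable and `jit_level="O3"`, so `set_jit_config` no longer raises on that combination. Usage itself is unchanged, mirroring the docstring example added above:

    import mindspore as ms
    from mindspore import nn

    net = nn.ReLU()
    net.set_jit_config(ms.JitConfig())  # no MS_ENABLE_GE/jit_level consistency check anymore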
@@ -1679,7 +1935,7 @@ class Cell(Cell_):
         to limit the maximum memory chunk size.

         Args:
-            fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
+            fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
         """
         if fusion_size < 0:
             raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
@@ -1718,9 +1974,7 @@ class Cell(Cell_):
         Examples:
             >>> import numpy as np
             >>> import mindspore as ms
-            >>>
-            >>> from mindspore import Tensor
-            >>> from mindspore.ops import GradOperation
+            >>> from mindspore import Tensor, nn, ops
             >>> ms.set_context(mode=ms.PYNATIVE_MODE)
             >>> def forward_pre_hook_fn(cell_id, inputs):
             ...     print("forward inputs: ", inputs)
@@ -1735,7 +1989,7 @@ class Cell(Cell_):
             ...         x = x + x
             ...         x = self.mul(x, y)
             ...         return x
-            >>> grad = GradOperation(get_all=True)
+            >>> grad = ops.GradOperation(get_all=True)
             >>> net = Net()
             >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
             forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
@@ -1820,9 +2074,7 @@ class Cell(Cell_):
         Examples:
             >>> import numpy as np
             >>> import mindspore as ms
-            >>>
-            >>> from mindspore import Tensor
-            >>> from mindspore.ops import GradOperation
+            >>> from mindspore import Tensor, nn, ops
             >>> ms.set_context(mode=ms.PYNATIVE_MODE)
             >>> def forward_hook_fn(cell_id, inputs, output):
             ...     print("forward inputs: ", inputs)
@@ -1838,7 +2090,7 @@ class Cell(Cell_):
             ...         x = x + x
             ...         x = self.mul(x, y)
             ...         return x
-            >>> grad = GradOperation(get_all=True)
+            >>> grad = ops.GradOperation(get_all=True)
             >>> net = Net()
             >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
             forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
@@ -1922,9 +2174,7 @@ class Cell(Cell_):
         Examples:
             >>> import numpy as np
             >>> import mindspore as ms
-            >>>
-            >>> from mindspore import Tensor
-            >>> from mindspore.ops import GradOperation
+            >>> from mindspore import Tensor, nn, ops
             >>> ms.set_context(mode=ms.PYNATIVE_MODE)
             >>> def backward_hook_fn(cell_id, grad_input, grad_output):
             ...     print("backward input: ", grad_input)
@@ -1940,7 +2190,7 @@ class Cell(Cell_):
             ...         x = x + x
             ...         x = self.relu(x)
             ...         return x
-            >>> grad = GradOperation(get_all=True)
+            >>> grad = ops.GradOperation(get_all=True)
             >>> net = Net()
             >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
             backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
@@ -1966,12 +2216,13 @@ class Cell(Cell_):
         handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
         return handle

-    def _backward_hook_construct(self, *inputs):
+    def _backward_hook_construct(self, *inputs, **kwargs):
         """
         Backward hook construct method to replace original construct method.

         Args:
             inputs: The input objects of Cell object.
+            kwargs (dict): Dictionary of variable keyword parameters.

         Returns:
             - **outputs** - The output objects of Cell object.
@@ -1983,10 +2234,11 @@ class Cell(Cell_):
             inputs = self._cell_backward_hook(inputs)
         else:
             inputs = self._cell_backward_hook(*inputs)
+            inputs = (inputs,)
         if isinstance(inputs, tuple):
-            outputs = self.construct(*inputs)
+            outputs = self.construct(*inputs, **kwargs)
         else:
-            outputs = self.construct(inputs)
+            outputs = self.construct(inputs, **kwargs)
         outputs = self._cell_backward_hook(outputs)
         return outputs

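Together with the `_run_construct` change earlier in this diff, keyword arguments now flow through `_backward_hook_construct` into `construct` instead of being dropped. A minimal sketch; `ScaledRelu` and its `scale` keyword are hypothetical names, not part of the package:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class ScaledRelu(nn.Cell):                   # hypothetical cell with a keyword argument
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()
        def construct(self, x, scale=1.0):
            return self.relu(x) * scale

    net = ScaledRelu()
    net.register_backward_hook(lambda cell_id, grad_in, grad_out: None)
    out = net(Tensor(np.ones([2]), ms.float32), scale=2.0)  # kwarg survives the hook path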
@@ -2000,23 +2252,16 @@ class Cell(Cell_):
         It is only supported in graph mode.

         Args:
-            recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
+            recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
             init_in_server (bool): Whether trainable parameters updated by parameter server are
-                initialized on server. Default: False.
+                initialized on server. Default: ``False`` .
         """
         params = self.trainable_params(recurse)
         for param in params:
             param.set_param_ps(init_in_server)

+    @deprecated("1.8", "set_param_fl")
     def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
-        """
-        Set the way of parameter and server interaction.
-
-        Args:
-            push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
-            pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
-            requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
-        """
         params = self.parameters_and_names()
         for param in params:
             param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
@@ -2031,7 +2276,7 @@ class Cell(Cell_):

         Args:
             fusion_type (int): The value of `comm_fusion`.
-            recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
+            recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
         """
         Validator.check_non_negative_int(fusion_type)
         for param in self.trainable_params(recurse):
@@ -2118,10 +2363,10 @@ class Cell(Cell_):

         Args:
             mp_comm_recompute (bool): Specifies whether the model parallel communication operators
-                in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
+                in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True`` .
             parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
                 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
-                Default: False.
+                Default: ``False`` .
         """
         self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
@@ -2133,7 +2378,7 @@ class Cell(Cell_):
                                "are not support recomputation in pipeline parallel.")
         elif context.get_auto_parallel_context("pipeline_stages") == 1:
             self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
-        if 'recompute_slice_activation' in kwargs
+        if 'recompute_slice_activation' in kwargs:
             self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))

         for key, _ in kwargs.items():
@@ -2217,19 +2462,29 @@ class Cell(Cell_):
         """
         if not isinstance(net_input, Tensor):
             raise TypeError(
-                f"
+                f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must be Tensor, "
+                f"but got {type(net_input)}.")
+        is_param_set_input = isinstance(set_input, Parameter)
+        is_param_net_input = isinstance(net_input, Parameter)
+        if (is_param_set_input and not is_param_net_input) or (is_param_net_input and not is_param_set_input):
+            raise TypeError(
+                f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
+                f"as network's input, but got 'set_inputs': {type(set_input)} and network's input: {type(net_input)}.")
         if set_input.dtype != net_input.dtype:
-            raise
-                f"
-                f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
-        if
-
-
-
-
-
-
+            raise TypeError(
+                f"For 'set_inputs' and tuple(list) in 'set_inputs',the dtype of {index + 1}th input must be the same "
+                f"as network's input, but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
+        if -2 not in set_input.shape:
+            if net_input.dim() != 0 and set_input.dim() != net_input.dim():
+                raise ValueError(
+                    f"For 'set_inputs' and tuple(list) in 'set_inputs',the dims of {index + 1}th input must be the "
+                    f"same as network's input, but got 'set_inputs': {set_input.dim()} and network's input: "
+                    f"{net_input.dim()}.")
+            if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+                raise ValueError(
+                    f"For 'set_inputs' and tuple(list) in 'set_inputs',the shape of {index + 1}th input must be the "
+                    f"same as network's input, but got 'set_inputs': {set_input.shape} and network's input: "
+                    f"{net_input.shape}.")

     def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
         """
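The checks above encode the shape rule for dynamic inputs: each `set_inputs` dimension must be `-1` (dynamic) or equal to the runtime dimension, and a `-2` in the shape marks a rank-dynamic input that skips the check entirely. A plain-Python mirror of the per-dimension test:

    def shapes_compatible(set_shape, net_shape):
        # Mirrors the guard added above: skip entirely for rank-dynamic (-2) shapes,
        # otherwise require equal rank and each dim to be -1 or an exact match.
        if -2 in set_shape:
            return True
        return len(set_shape) == len(net_shape) and \
            all(s in (-1, n) for s, n in zip(set_shape, net_shape))

    print(shapes_compatible((3, -1), (3, 10)))  # True
    print(shapes_compatible((3, -1), (4, 10)))  # False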
@@ -2241,22 +2496,61 @@ class Cell(Cell_):
         set_inputs_len = len(set_inputs)
         net_inputs_len = len(net_inputs)
         if set_inputs_len != net_inputs_len:
-            raise ValueError("The length of 'set_inputs' must be equal to network's
-                             f"but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
+            raise ValueError("The length of 'set_inputs' or tuple(list) in 'set_inputs' must be equal to network's "
+                             f"inputs, but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
         for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
             if isinstance(set_input, Tensor):
                 self._check_dynamic_tensor(set_input, net_input, index)
             elif isinstance(set_input, (tuple, list)):
                 if not isinstance(net_input, (tuple, list)):
                     raise TypeError(
-                        f"The {index + 1}th input type of 'set_inputs' must be tuple or
-                        f"but got {type(net_input)}.")
+                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in 'set_inputs' must be tuple or "
+                        f"list, but got {type(net_input)}.")
                 self._check_compile_dynamic_shape(set_input, net_input)
             else:
+                if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
+                    continue
                 if net_input != set_input:
                     raise ValueError(
-                        f"The {index + 1}th input of 'set_inputs' must be the same with
-                        f"set_inputs: {set_input} and network's input: {net_input}.")
+                        f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
+                        f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
+
+    def _run_tracefunc(self, *args, **kwargs):
+        """ Run Packed Cell in Pack."""
+        args = self._mixed_precision_cast(args)
+        need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
+        if not PackFunc.current.is_pynative_mode and need_subgraph:
+            expander = PackExpander.get_instance()
+            args = expander.begin_subgraph(self, *args)
+            args = [_convert_tensor(a) for a in args]
+            output = self._run_construct(args, kwargs)
+            ret = expander.end_subgraph(self, output)
+            output = _convert_tensor(ret)
+        else:
+            with _SetMixedPrecision(self):
+                output = self._run_construct(args, kwargs)
+        return output
+
+    def _mixed_precision_cast(self, inputs):
+        mixed_type = self.get_mixed_precision_type()
+        if mixed_type == MixedPrecisionType.NOTSET:
+            return inputs
+        if mixed_type == MixedPrecisionType.FP16:
+            cast_type = mstype.float16
+        elif mixed_type == MixedPrecisionType.BF16:
+            cast_type = mstype.bfloat16
+        else:
+            cast_type = mstype.float32
+        cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
+        return cast_inputs
+
+    def _get_attr_from_cell(self, network):
+        if not isinstance(network, Cell):
+            return
+        if hasattr(network, "jit_config_dict"):
+            self._jit_config_dict = network.jit_config_dict
+        if hasattr(network, "_amp_level"):
+            self._amp_level = getattr(network, "_amp_level")


 class GraphCell(Cell):
@@ -2271,11 +2565,11 @@ class GraphCell(Cell):
         params_init (dict): Parameters need to be inited in the graph.
             The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
             If the parameter exists in the graph according to the name, update it's value.
-            If the parameter does not exist, ignore it. Default: None.
+            If the parameter does not exist, ignore it. Default: ``None`` .
         obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
             used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
             a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
-            provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: None.
+            provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` .

     Raises:
         TypeError: If the `graph` is not a FuncGraph.