mindspore-1.10.0-cp38-cp38-win_amd64.whl → mindspore-2.0.0rc1-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
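Beyond the DLL churn, several renames above relocate public Python APIs: the metrics package moves from mindspore.nn to mindspore.train (the `mindspore/{nn → train}/metrics/*` entries), the transformer building blocks move from `mindspore.nn.transformer` to the private `mindspore.parallel._transformer`, and the `mindspore.compression` package is deleted outright. A minimal sketch of the import change for the metrics move, assuming `Accuracy` is still re-exported from the package `__init__.py` as the moved files suggest:

    # mindspore 1.10.0, old import path (removed in 2.0.0rc1):
    # from mindspore.nn.metrics import Accuracy

    # mindspore 2.0.0rc1: the metrics files now live under mindspore/train/metrics/
    from mindspore.train.metrics import Accuracy

    metric = Accuracy('classification')  # constructor unchanged by the move
    metric.clear()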
mindspore/nn/cell.py
CHANGED
|
@@ -29,18 +29,42 @@ from mindspore import log as logger
|
|
|
29
29
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
|
30
30
|
from mindspore.common.hook_handle import HookHandle
|
|
31
31
|
from mindspore.context import ParallelMode
|
|
32
|
-
from mindspore.ops.composite import Shard
|
|
33
32
|
from mindspore import context
|
|
34
33
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
|
35
|
-
from mindspore
|
|
34
|
+
from mindspore import _checkparam as Validator
|
|
36
35
|
from mindspore.common import dtype as mstype
|
|
37
36
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
37
|
+
from mindspore.common.api import _generate_branch_control_input
|
|
38
38
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
39
39
|
from mindspore.common.tensor import Tensor
|
|
40
40
|
from mindspore.ops.operations import Cast
|
|
41
41
|
from mindspore.ops.primitive import Primitive
|
|
42
42
|
from mindspore.ops.operations import _inner_ops as inner
|
|
43
|
-
from mindspore.parallel.
|
|
43
|
+
from mindspore.parallel.shard import Shard
|
|
44
|
+
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_args(args):
|
|
48
|
+
"""Check the input args's type"""
|
|
49
|
+
index = 1
|
|
50
|
+
for item in args:
|
|
51
|
+
if isinstance(item, Tensor) and item.has_init:
|
|
52
|
+
item.init_data()
|
|
53
|
+
elif isinstance(item, numpy.ndarray):
|
|
54
|
+
suffix = "th"
|
|
55
|
+
if index == 1:
|
|
56
|
+
suffix = "st"
|
|
57
|
+
elif index == 2:
|
|
58
|
+
suffix = "nd"
|
|
59
|
+
elif index == 3:
|
|
60
|
+
suffix = "rd"
|
|
61
|
+
|
|
62
|
+
input_index = str(index) + suffix
|
|
63
|
+
raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
|
|
64
|
+
f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
|
|
65
|
+
f"), and tuple or list containing only these types, and dict whose values are these "
|
|
66
|
+
f"types, but the {input_index} arg type is {type(item)}.")
|
|
67
|
+
index += 1
|
|
44
68
|
|
|
45
69
|
|
|
46
70
|
class Cell(Cell_):
|
|
@@ -54,11 +78,14 @@ class Cell(Cell_):
|
|
|
54
78
|
PYNATIVE_MODE (dynamic graph mode).
|
|
55
79
|
|
|
56
80
|
Args:
|
|
57
|
-
auto_prefix (bool): Whether to automatically generate NameSpace for Cell and its
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
81
|
+
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
|
|
82
|
+
affects the names of parameters in the `Cell`. If set to True, the parameter name will be
|
|
83
|
+
automatically prefixed, otherwise not. In general, the backbone network should be set to True,
|
|
84
|
+
otherwise the duplicate name problem will appear. The cell to train the backbone network, such as
|
|
85
|
+
optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to False, otherwise the
|
|
86
|
+
parameter name in backbone will be changed by mistake. Default: True.
|
|
87
|
+
flags (dict, optional): Network configuration information, currently it is used for the binding of network
|
|
88
|
+
and dataset. Users can also customize network attributes by this parameter. Default: None.
|
|
62
89
|
|
|
63
90
|
Supported Platforms:
|
|
64
91
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -84,12 +111,11 @@ class Cell(Cell_):
|
|
|
84
111
|
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
|
|
85
112
|
"""
|
|
86
113
|
|
|
87
|
-
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '
|
|
88
|
-
'
|
|
89
|
-
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
|
|
114
|
+
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
|
|
115
|
+
'_func_graph_flags', '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
|
|
90
116
|
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
|
|
91
117
|
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
|
|
92
|
-
'_attr_synced', 'pynative', 'requires_grad', '
|
|
118
|
+
'_attr_synced', 'pynative', 'requires_grad', 'cell_type']
|
|
93
119
|
|
|
94
120
|
def __init__(self, auto_prefix=True, flags=None):
|
|
95
121
|
Cell_.__init__(self, self._cell_tag)
|
|
@@ -123,10 +149,6 @@ class Cell(Cell_):
|
|
|
123
149
|
if os.getenv('GC_COLLECT_IN_CELL') == '1':
|
|
124
150
|
gc.collect()
|
|
125
151
|
|
|
126
|
-
self._construct_inputs_num = 0
|
|
127
|
-
self._construct_inputs_names = []
|
|
128
|
-
self._auto_parallel_mode = False
|
|
129
|
-
self._parallel_inputs_run = None
|
|
130
152
|
if flags:
|
|
131
153
|
self.add_flags(**flags)
|
|
132
154
|
self._bprop_debug = False
|
|
@@ -136,8 +158,8 @@ class Cell(Cell_):
|
|
|
136
158
|
self._enable_forward_hook = False
|
|
137
159
|
self._enable_backward_hook = False
|
|
138
160
|
self._cell_backward_hook = None
|
|
161
|
+
self._is_recursion_hook = False
|
|
139
162
|
self.cell_type = None
|
|
140
|
-
self._auto_parallel_compile_and_run = False
|
|
141
163
|
self.cast = Cast()
|
|
142
164
|
self._has_config_recompute = False
|
|
143
165
|
self._user_parameters = []
|
|
@@ -145,6 +167,7 @@ class Cell(Cell_):
         self.saved_dynamic_shape = None
         self._jit_config_dict = dict()
         self.grad_ops_label = False
+        self.to_float_fp16 = False
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -156,6 +179,9 @@ class Cell(Cell_):
         self.__dict__ = dict_
         self._attr_synced = False
 
+    def __bool__(self):
+        return True
+
     @property
     def _cell_tag(self):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
@@ -310,8 +336,6 @@ class Cell(Cell_):
         if '_params' in self.__dict__:
             params = self.__dict__['_params']
             if name in params:
-                if context._get_mode() == context.PYNATIVE_MODE:
-                    return self.cast_param(params[name])
                 return params[name]
         if '_cells' in self.__dict__:
             cells = self.__dict__['_cells']
@@ -320,27 +344,23 @@ class Cell(Cell_):
         if '_tensor_list' in self.__dict__:
             tensor_list = self.__dict__['_tensor_list']
             if name in tensor_list:
-                return
+                return tensor_list[name]
         if '_params_list' in self.__dict__:
             params_list = self.__dict__['_params_list']
             if name in params_list:
-
-                cast_list = list()
-                for para in para_list:
-                    cast_list.append(self.cast_param(para))
-                para_list = ParameterTuple(cast_list)
-                return para_list
+                return ParameterTuple(params_list[name])
         raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
 
     def __del__(self):
-        if context.get_context is not None and context._get_mode() == context.PYNATIVE_MODE:
-            _pynative_executor.del_cell(self)
-
         # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
         # here using pop(id(self), None) to avoid KeyError exception
         cells_compile_cache.pop(id(self), None)
-
-
+        try:
+            if self.compile_cache:
+                _cell_graph_executor.del_net_res(self, self.compile_cache)
+        except AttributeError as e:
+            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+                                 f"Please use 'super().__init__()'.") from e
 
     def __delattr__(self, name):
         if name in self._params:
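The new `__del__` body (and `_hook_fn_registered` further down) converts a bare AttributeError into a hint about one specific user error: subclassing `Cell` without calling `super().__init__()`. A sketch of the failure mode (illustrative class name):

    import mindspore.nn as nn

    class BadNet(nn.Cell):
        def __init__(self):
            # super().__init__() is missing, so '_params', '_cells' and the other
            # bookkeeping dicts are never created; Cell's internals now raise
            # AttributeError with the "Please use 'super().__init__()'." hint.
            self.dense = nn.Dense(2, 2)

        def construct(self, x):
            return self.dense(x)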
@@ -391,7 +411,7 @@ class Cell(Cell_):
     def _do_parameter_broadcast(self):
         if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
             if not self.parameter_broadcast_done:
-                _pynative_executor.parameter_broadcast(self, self.phase
+                _pynative_executor.parameter_broadcast(self, self.phase)
                 self.parameter_broadcast_done = True
 
     def run_construct(self, cast_inputs, kwargs):
@@ -427,39 +447,51 @@ class Cell(Cell_):
             output = self._run_forward_hook(cast_inputs, output)
         return output
 
-    def _check_construct_args(self, *
+    def _check_construct_args(self, *args):
         """Check the args needed by the function construct"""
-        if kwargs:
-            raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
-                             f"or there is a key in kwargs that is not used as a function argument. "
-                             f"args: {inputs}, kwargs: {kwargs}")
         positional_args = 0
         default_args = 0
+        has_var = False
         for value in inspect.signature(self.construct).parameters.values():
             if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
-
+                has_var = True
             if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                 if value.default is inspect.Parameter.empty:
                     positional_args += 1
                 else:
                     default_args += 1
 
-        if
+        if has_var:
+            return
+
+        if len(args) < positional_args:
             raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
-                            f"but got {len(
+                            f"but got {len(args)}. When using set_inputs, please make sure that all networks "
                             f"and loss functions are configured with set_inputs.")
 
-        if len(
+        if len(args) > positional_args + default_args:
+            construct_inputs_names = self.construct.__code__.co_varnames
+            if 'self' not in construct_inputs_names:
+                raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
+
             raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
                             f"{default_args} default argument, total {positional_args + default_args}, "
-                            f"but got {len(
+                            f"but got {len(args)}.")
 
     def _hook_fn_registered(self):
-
-
-
-        if
+        '''Hook function in graph mode'''
+        # Check super().__init__() in graph mode.
+        try:
+            if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
                 return True
+        except AttributeError as e:
+            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+                                 f"Please use 'super().__init__()'.") from e
+        if not self._is_recursion_hook:
+            self._is_recursion_hook = True
+            for cell in self.cells():
+                if cell._hook_fn_registered():
+                    return True
         return False
 
     def _get_prims_recursively(self):
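The reworked `_check_construct_args` counts only positional-or-keyword parameters and returns early when `construct` declares `*args`/`**kwargs`. The counting logic, restated as standalone Python for reference:

    import inspect

    def count_construct_params(fn):
        """Mirrors _check_construct_args: (required, with_default, has_var)."""
        positional_args, default_args, has_var = 0, 0, False
        for value in inspect.signature(fn).parameters.values():
            if value.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                has_var = True
            if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                if value.default is inspect.Parameter.empty:
                    positional_args += 1
                else:
                    default_args += 1
        return positional_args, default_args, has_var

    def construct(x, y, scale=1.0):
        return (x + y) * scale

    print(count_construct_params(construct))   # (2, 1, False)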
@@ -494,7 +526,7 @@ class Cell(Cell_):
         for prim in all_prims:
             prim.add_prim_attr("strategy_gen_mode", "data_parallel")
 
-    def shard(self, in_strategy, out_strategy, parameter_plan=None, device="Ascend", level=0):
+    def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         """
         Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
         generated by sharding propagation. In PyNative mode, use this method
@@ -508,11 +540,13 @@ class Cell(Cell_):
         Note:
             Only effective in PYNATIVE_MODE and in either ParallelMode.AUTO_PARALLEL with
             search_mode in auto_parallel_context set as sharding_propagation.
+            If the input contain Parameter, its strategy should be set in `in_strategy`.
 
         Args:
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
                 defines the layout of the corresponding input and None represents a data parallel strategy.
-            out_strategy (tuple): Define the layout of outputs similar with in_strategy.
+            out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
+                It is not in use right now. Default: None.
             parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                 defines the layout of the parameter like "param_name: layout".
                 The key is a parameter name of type 'str'.
@@ -552,7 +586,11 @@ class Cell(Cell_):
         ...         x = self.block2(x)
         ...         return x
         """
-
+        if context.get_context("mode") != context.PYNATIVE_MODE or \
+                context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
+            raise AssertionError(f"Cell shard only supports auto parallel under PyNative mode. "
+                                 f"Please check if you call Cell.shard in the script.")
+
         shard_fn = Shard()
         fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
         object.__setattr__(self, "_shard_fn", fn)
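The new guard makes the preconditions from the Note explicit. A sketch of a conforming call site (the strategy values are illustrative, and a real multi-device launch is needed for this to actually execute):

    import mindspore.nn as nn
    from mindspore import context

    # shard() now asserts exactly this configuration up front:
    context.set_context(mode=context.PYNATIVE_MODE)
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      search_mode="sharding_propagation")

    block = nn.Dense(8, 8)
    # Split the first input along axis 0 across 2 devices; sharding
    # propagation derives strategies for the remaining operators.
    block.shard(in_strategy=((2, 1),), out_strategy=None)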
@@ -568,6 +606,8 @@ class Cell(Cell_):
         Returns:
             Tuple, the inputs after data type cast.
         """
+        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
+        logger.warning(msg)
         cast_inputs = inputs
         mixed_type = self.get_mixed_precision_type()
         if mixed_type == MixedPrecisionType.FP16:
|
|
|
577
617
|
|
|
578
618
|
return cast_inputs
|
|
579
619
|
|
|
580
|
-
def _check_args(self, args):
|
|
581
|
-
"""Check the input args's type"""
|
|
582
|
-
index = 1
|
|
583
|
-
for item in args:
|
|
584
|
-
if isinstance(item, Tensor) and item.has_init:
|
|
585
|
-
item.init_data()
|
|
586
|
-
elif isinstance(item, numpy.ndarray):
|
|
587
|
-
suffix = "th"
|
|
588
|
-
if index == 1:
|
|
589
|
-
suffix = "st"
|
|
590
|
-
elif index == 2:
|
|
591
|
-
suffix = "nd"
|
|
592
|
-
elif index == 3:
|
|
593
|
-
suffix = "rd"
|
|
594
|
-
|
|
595
|
-
input_index = str(index) + suffix
|
|
596
|
-
raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
|
|
597
|
-
f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
|
|
598
|
-
f"), and tuple or list containing only these types, and dict whose values are these "
|
|
599
|
-
f"types, but the {input_index} arg type is {type(item)}.")
|
|
600
|
-
index += 1
|
|
601
|
-
|
|
602
620
|
def __call__(self, *args, **kwargs):
|
|
603
621
|
if self.__class__.construct is Cell.construct:
|
|
604
|
-
|
|
605
|
-
|
|
622
|
+
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
|
|
623
|
+
|
|
606
624
|
if kwargs:
|
|
607
625
|
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
|
608
626
|
bound_arguments.apply_defaults()
|
|
@@ -610,34 +628,33 @@ class Cell(Cell_):
             kwargs = bound_arguments.kwargs
 
         # Run in Graph mode.
-        if context._get_mode() == context.GRAPH_MODE:
-            self._check_construct_args(*args
+        if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+            self._check_construct_args(*args)
             if self._hook_fn_registered():
                 logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
                                f"function, please use context.set_context to set pynative mode.")
-            out = self.compile_and_run(*args)
+            out = self.compile_and_run(*args, **kwargs)
             return out
 
         # Run in PyNative mode.
         if _pynative_executor.is_first_cell():
-            _pynative_executor.set_lazy_build(True)
             _pynative_executor._optimizer = getattr(self, "optimizer", None)
             _pynative_executor._top_cell = self
-            # There many Casts in parameter_broadcast. Enable
+            # There many Casts in parameter_broadcast. Enable build faster.
             self._do_parameter_broadcast()
 
-
+        _check_args(args)
+        self._check_cell_flags_in_pynative()
 
         if self.requires_grad:
            _pynative_executor.set_grad_flag(True)
 
         if self._dynamic_shape_inputs is not None:
-            self._check_compile_dynamic_shape(
+            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
 
         try:
             _pynative_executor.new_graph(self, *args, **kwargs)
-
-            output = self._run_construct(cast_inputs, kwargs)
+            output = self._run_construct(args, kwargs)
             _pynative_executor.end_graph(self, output, *args, **kwargs)
         except Exception as err:
             _pynative_executor.clear_res()
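Graph-mode dispatch is now also gated on the `MS_JIT` environment variable; per the condition above, any value other than '0' (including unset) keeps compilation on. A sketch of the opt-out switch:

    import os
    os.environ["MS_JIT"] = "0"     # force the PyNative execution path

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    ms.set_context(mode=ms.GRAPH_MODE)
    net = nn.Dense(4, 3)
    x = ms.Tensor(np.ones((2, 4), np.float32))
    y = net(x)                     # runs op-by-op despite GRAPH_MODE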
@@ -647,6 +664,12 @@ class Cell(Cell_):
             output = output.data
         return output
 
+    def _check_cell_flags_in_pynative(self):
+        """Check the flags added to cell in pynative mode"""
+        if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
+            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
+                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
+
     def _add_attr(self, name, value):
         if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
             super(Cell, self)._add_attr(name, value)
@@ -829,84 +852,19 @@ class Cell(Cell_):
         """
         Replace parameters with sliced tensors by parallel strategies.
 
-
-
-        Args:
-            params (dict): The parameters dictionary used for initializing the data graph.
-        """
-        if params is None:
-            params = self.parameters_dict()
-        if isinstance(params, OrderedDict):
-            for key in params:
-                tensor = params[key].data
-                if key not in self.parameter_layout_dict:
-                    logger.info("The layout dict does not contain the key %s.", key)
-                    continue
-                if params[key].sliced:
-                    logger.debug("The param %s is already sliced.", key)
-                    continue
-                layout = self.parameter_layout_dict[key]
-                new_tensor = _load_tensor_by_layout(tensor, layout)
-                params[key].set_data(new_tensor, True)
-        else:
-            raise TypeError("For 'load_parameter_slice', the argument 'params' must be OrderedDict type, "
-                            "but got {}.".format(type(params)))
-
-    def _load_inputs(self, *inputs):
+        Note:
+            This interface is deprecated.
         """
-
-
-        Args:
-            inputs (Function or Cell): inputs of construct method.
-        """
-        parallel_inputs_run = []
-        # judge if *args exists in input
-        if self.argspec[1] is not None:
-            prefix = self.argspec[1]
-            for i in range(len(inputs)):
-                key = prefix + str(i)
-                self._construct_inputs_names = self._construct_inputs_names + (key,)
-                self._construct_inputs_num = self._construct_inputs_num + 1
-        for i, tensor in enumerate(inputs):
-            key = self._construct_inputs_names[i]
-            # if input is not used, self.parameter_layout_dict may not contain the key
-            if key not in self.parameter_layout_dict:
-                logger.warning("Layout dict does not contain the key %s.", key)
-                parallel_inputs_run.append(tensor)
-            else:
-                layout = self.parameter_layout_dict[key]
-                new_tensor = _load_tensor_by_layout(tensor, layout)
-                parallel_inputs_run.append(new_tensor)
-        return tuple(parallel_inputs_run)
+        logger.warning("'load_parameter_slice' function is deprecated.")
 
     def set_parallel_input_with_inputs(self, *inputs):
         """
         Slice inputs tensors by parallel strategies.
 
-
-
-        """
-
-
-    def _get_construct_inputs_number_and_name(self):
-        """Compute self._construct_inputs_names and self._construct_inputs_num"""
-        from mindspore._extends.parse.parser import get_parse_method_of_class
-
-        fn = get_parse_method_of_class(self)
-        self.argspec = inspect.getfullargspec(fn)
-        self._construct_inputs_num = fn.__code__.co_argcount
-        self._construct_inputs_names = fn.__code__.co_varnames
-
-        if self._construct_inputs_num <= 0:
-            raise ValueError(f"For 'set_auto_parallel', the number of inputs must be greater than 0,"
-                             f"but got {self._construct_inputs_num}.")
-        if self._construct_inputs_names[0] != 'self':
-            raise ValueError(f"First member of fn function must be self, but got {self._construct_inputs_names[0]}")
-        if self._construct_inputs_num - 1 > len(self._construct_inputs_names):
-            raise ValueError(f"Num of inputs must be greater than num of fn function members, num of inputs is \
-                {self._construct_inputs_names - 1}, num of fn function members is {len(self._construct_inputs_names)}")
-        self._construct_inputs_names = self._construct_inputs_names[1:self._construct_inputs_num]
-        self._construct_inputs_num = self._construct_inputs_num - 1
+        Note:
+            This interface is deprecated.
+        """
+        logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
 
     def set_inputs(self, *inputs):
         """
@@ -917,8 +875,8 @@ class Cell(Cell_):
         Args:
             inputs (tuple): Inputs of the Cell object.
 
-
-        This is an experimental
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
 
         Examples:
             >>> import numpy as np
@@ -949,7 +907,7 @@ class Cell(Cell_):
         if self._dynamic_shape_inputs:
             ds.config.set_dynamic_shape(True)
             if context._get_mode() == context.PYNATIVE_MODE:
-                _pynative_executor.set_dynamic_input(self
+                _pynative_executor.set_dynamic_input(self)
 
     def get_inputs(self):
         """
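`set_inputs` now also notifies the PyNative executor. For orientation, the usual dynamic-shape setup looks like this (a sketch; a None dim in the placeholder marks that dimension as dynamic):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.Dense(10, 4)
    # shape=[None, 10] marks the batch dimension as dynamic; later inputs
    # must match the rank, dtype, and every fixed dimension.
    net.set_inputs(Tensor(shape=[None, 10], dtype=ms.float32))

    x = Tensor(np.random.randn(32, 10).astype(np.float32))
    out = net(x)    # compiled once against the dynamic signature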
@@ -958,36 +916,31 @@ class Cell(Cell_):
         Returns:
             inputs (tuple), Inputs of the Cell object.
 
-
-        This is an experimental
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
         """
 
         return self._dynamic_shape_inputs
 
-    def compile(self, *
+    def compile(self, *args, **kwargs):
         """
         Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
 
         Args:
-
+            args (tuple): Args of the Cell object.
+            kwargs (dict): Kwargs of the Cell object.
         """
-        if self._dynamic_shape_inputs is None
-        _cell_graph_executor.compile(self,
-                                     jit_config_dict=self._jit_config_dict)
+        if self._dynamic_shape_inputs is None:
+            _cell_graph_executor.compile(self, phase=self.phase,
+                                         jit_config_dict=self._jit_config_dict, *args, **kwargs)
         else:
-            self._check_compile_dynamic_shape(
-            if self.saved_dynamic_shape:
-                for i in range(len(self.saved_dynamic_shape)):
-                    if self.saved_dynamic_shape[i].shape != self._dynamic_shape_inputs[i].shape:
-                        return
-
+            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
             self.saved_dynamic_shape = self._dynamic_shape_inputs
             _cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
-
-                                         jit_config_dict=self._jit_config_dict)
+                                         jit_config_dict=self._jit_config_dict, **kwargs)
             logger.debug("Compiled Graph with dynamic shape")
 
-    def compile_and_run(self, *
+    def compile_and_run(self, *args, **kwargs):
         """
         Compile and run Cell, the input must be consistent with the input defined in construct.
 
@@ -995,25 +948,25 @@ class Cell(Cell_):
         It is not recommended to call directly.
 
         Args:
-
+            args (tuple): Args of the Cell object.
+            kwargs (dict): Kwargs of the Cell object.
 
         Returns:
             Object, the result of executing.
         """
-        self.
-        self.compile(*inputs)
+        self.compile(*args, **kwargs)
 
-
-        return _cell_graph_executor(self, *
+        new_args = _get_args_for_run(self, args, kwargs)
+        return _cell_graph_executor(self, *new_args, phase=self.phase)
 
     def auto_parallel_compile_and_run(self):
         """
         Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
 
-
-
+        Note:
+            This interface is deprecated.
         """
-
+        logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
 
     def exec_checkpoint_graph(self):
         """Executes saving checkpoint graph operation."""
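With `compile` and `compile_and_run` both taking `*args, **kwargs`, pre-compilation and the later call share one signature. A short usage sketch:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    ms.set_context(mode=ms.GRAPH_MODE)
    net = nn.Dense(4, 3)
    x = ms.Tensor(np.ones((2, 4), np.float32))
    net.compile(x)    # build the graph ahead of time
    out = net(x)      # the call reuses the compile cache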
@@ -1063,6 +1016,8 @@ class Cell(Cell_):
         Returns:
             Parameter, the input parameter with type automatically cast.
         """
+        msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
+        logger.warning(msg)
         mixed_type = self.get_mixed_precision_type()
         if mixed_type != MixedPrecisionType.NOTSET:
             if mixed_type == MixedPrecisionType.FP32:
@@ -1084,8 +1039,12 @@ class Cell(Cell_):
 
         Raises:
             KeyError: Child Cell's name is incorrect or duplicated with the other child name.
+            TypeError: If type of `child_name` is not str.
             TypeError: Child Cell's type is incorrect.
         """
+        if not isinstance(child_name, str):
+            raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
+                            f"but got {type(child_name)}.")
         if not child_name or '.' in child_name:
             raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
                            "can not contain '.'")
@@ -1097,7 +1056,7 @@ class Cell(Cell_):
                             f"but got type {type(child_cell)}.")
         self._cells[child_name] = child_cell
 
-    def construct(self, *
+    def construct(self, *args, **kwargs):
         """
         Defines the computation to be performed. This method must be overridden by all subclasses.
 
@@ -1105,7 +1064,7 @@ class Cell(Cell_):
         It is not supported currently that inputs contain both tuple and non-tuple types at same time.
 
         Args:
-
+            args (tuple): Tuple of variable parameters.
             kwargs (dict): Dictionary of variable keyword parameters.
 
         Returns:
@@ -1158,15 +1117,7 @@ class Cell(Cell_):
         def _updata(param):
             if param in replace:
                 return replace.get(param)
-
-            set_sliced = False
-            if auto_parallel_mode:
-                set_sliced = True
-                if param.name not in self.parameter_layout_dict:
-                    logger.debug("Layout dict does not contain the key %s.", param.name)
-                else:
-                    layout = self.parameter_layout_dict[param.name]
-            new_p = param.init_data(layout, set_sliced=set_sliced)
+            new_p = param.init_data(None, set_sliced=False)
             replace[param] = new_p
             return new_p
@@ -1265,6 +1216,7 @@ class Cell(Cell_):
             param.is_init = False
             param.name = prefix + name
 
+    @jit_forbidden_register
     def trainable_params(self, recurse=True):
         """
         Returns all trainable parameters.
@@ -1279,6 +1231,7 @@ class Cell(Cell_):
         """
         return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
 
+    @jit_forbidden_register
     def untrainable_params(self, recurse=True):
         """
         Returns all untrainable parameters.
@@ -1293,6 +1246,7 @@ class Cell(Cell_):
         """
         return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
 
+    @jit_forbidden_register
     def get_parameters(self, expand=True):
         """
         Returns an iterator over cell parameters.
@@ -1484,6 +1438,38 @@ class Cell(Cell_):
         if "fp32" in flags and flags.get("fp32", False):
             self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
 
+    def apply(self, fn):
+        """
+        Applies fn recursively to every subcell (as returned by .cells()) as well as self.
+        Typical use includes initializing the parameters of a model.
+
+        Args:
+            fn (function): function to be applied to each subcell.
+
+        Returns:
+            Cell, self.
+
+        Examples:
+            >>> import mindspore.nn as nn
+            >>> from mindspore.common.initializer import initializer, One
+            >>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
+            >>> def func(cell):
+            ...     if isinstance(cell, nn.Dense):
+            ...         cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
+            >>> net.apply(func)
+            SequentialCell<
+              (0): Dense<input_channels=2, output_channels=2, has_bias=True>
+              (1): Dense<input_channels=2, output_channels=2, has_bias=True>
+              >
+            >>> print(net[0].weight.asnumpy())
+            [[1. 1.]
+             [1. 1.]]
+        """
+        for cell in self.cells():
+            cell.apply(fn)
+        fn(self)
+        return self
+
     def add_flags(self, **flags):
         """
         Add customized attributes for cell.
@@ -1538,7 +1524,7 @@ class Cell(Cell_):
         Add cast on all inputs of cell and child cells to run with certain float type.
 
         If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will
-        be cast to float16. Please refer to the usage in source code of :func:`mindspore.build_train_network`.
+        be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`.
 
         Note:
             Multiple calls will overwrite.
@@ -1554,7 +1540,7 @@ class Cell(Cell_):
             ValueError: If dst_type is not mstype.float32 or mstype.float16.
 
         Supported Platforms:
-
+            ``Ascend`` ``GPU`` ``CPU``
 
         Examples:
             >>> import mindspore.nn as nn
@@ -1570,8 +1556,10 @@ class Cell(Cell_):
                             "but got {}.".format(dst_type))
         if dst_type == mstype.float16:
             self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
+            self.to_float_fp16 = True
         else:
             self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+            self.to_float_fp16 = False
         flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
         self._add_init_args(**flags)
         return self
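The new `to_float_fp16` flag simply mirrors the `dst_type` choice. Typical usage of `to_float` itself, as a sketch:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor

    net = nn.Dense(4, 3).to_float(ms.float16)   # cast all cell inputs to fp16
    x = Tensor(np.ones((2, 4), np.float32))
    print(net(x).dtype)                          # Float16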
@@ -1582,7 +1570,7 @@ class Cell(Cell_):
         accelerate the algorithm in the algorithm library.
 
         If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
-        `algorithm library <https://gitee.com/mindspore/mindspore/tree/
+        `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/python/mindspore/boost>`_.
 
         Note:
             Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1627,6 +1615,10 @@ class Cell(Cell_):
         for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
         set to true, the training branch will be executed, otherwise another branch.
 
+        Note:
+            When execute function Model.train(), framework will call Cell.set_train(True).
+            When execute function Model.eval(), framework will call Cell.set_train(False).
+
         Args:
             mode (bool): Specifies whether the model is training. Default: True.
 
@@ -1655,11 +1647,9 @@ class Cell(Cell_):
         Set the cell to auto parallel mode.
 
         Note:
-
-            this interface needs to be called by the cell.
+            This interface is deprecated.
         """
-
-        self._get_construct_inputs_number_and_name()
+        logger.warning("'set_auto_parallel' function is deprecated.")
 
     def set_jit_config(self, jit_config):
         """
@@ -1672,6 +1662,11 @@ class Cell(Cell_):
             logger.warning("For Cell, jit config can only be set once, ignore this setting.")
         else:
             self._jit_config_dict = jit_config.jit_config_dict
+            enable_ge = os.getenv("MS_ENABLE_GE") == '1'
+            enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
+            if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
+                raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
+                                   format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
 
     def flatten_weights(self, fusion_size=0):
         """
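Per the new consistency check, `jit_level` "O3" is only legal together with `MS_ENABLE_GE=1`, and vice versa. A sketch of an ordinary (non-GE) configuration; the "O1" level here is illustrative:

    import mindspore as ms
    import mindspore.nn as nn

    net = nn.Dense(4, 3)
    # Any level other than "O3" is fine without MS_ENABLE_GE=1.
    net.set_jit_config(ms.JitConfig(jit_level="O1"))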
@@ -1695,7 +1690,7 @@ class Cell(Cell_):
         Register forward pre hook function for Cell object.
 
         Note:
-            - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or
+            - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - 'hook_fn' must be defined as the following code.
               `cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
               input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
@@ -1758,7 +1753,7 @@ class Cell(Cell_):
             raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
                             f"function, but got {type(hook_fn)}.")
         if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@
+            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
 
         self._enable_forward_pre_hook = True
         _pynative_executor.set_hook_changed(self)
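Per the Note above, hooks only fire in PyNative mode and `hook_fn` takes `(cell_id, inputs)`. A minimal sketch (the handle's `remove()` follows the documented hook API):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def forward_pre_hook(cell_id, inputs):
        print("about to run:", cell_id)
        return None            # return a new tuple of inputs to replace them

    net = nn.Dense(4, 3)
    handle = net.register_forward_pre_hook(forward_pre_hook)
    net(ms.Tensor(np.ones((2, 4), np.float32)))
    handle.remove()            # deregister when done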
@@ -1797,7 +1792,7 @@ class Cell(Cell_):
         Set the Cell forward hook function.
 
         Note:
-            - The `register_forward_hook(hook_fn)` does not work in graph mode or
+            - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - 'hook_fn' must be defined as the following code.
              `cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
              input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
|
|
|
1862
1857
|
raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
|
|
1863
1858
|
f"function, but got {type(hook_fn)}.")
|
|
1864
1859
|
if hook_fn.__code__.co_name == "staging_specialize":
|
|
1865
|
-
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@
|
|
1860
|
+
raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
|
|
1866
1861
|
|
|
1867
1862
|
self._enable_forward_hook = True
|
|
1868
1863
|
_pynative_executor.set_hook_changed(self)
|
|
@@ -1899,7 +1894,7 @@ class Cell(Cell_):
         Register the backward hook function.
 
         Note:
-            - The `register_backward_hook(hook_fn)` does not work in graph mode or
+            - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - The 'hook_fn' must be defined as the following code.
              `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
              gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
@@ -2002,6 +1997,7 @@ class Cell(Cell_):
 
         Note:
             It only works when a running task is in the parameter server mode.
+            It is only supported in graph mode.
 
         Args:
             recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
@@ -2083,9 +2079,6 @@ class Cell(Cell_):
         """
         Set the cell recomputed.
         """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise TypeError("Recompute is not supported in pynative mode currently, you can use "
-                            "'context.set_context(mode=context.GRAPH_MODE)' to set graph mode.")
         Validator.check_bool(mode)
         Validator.check_bool(output_recompute)
         if not self._has_config_recompute:
@@ -2184,35 +2177,86 @@ class Cell(Cell_):
                     params.append(param)
         return params
 
-    def
+    def place(self, role, rank_id):
         """
-
+        Set the label for all operators in this cell.
+        This label tells MindSpore compiler on which process this cell should be launched.
+        And each process's identical label consists of input `role` and `rank_id`.
+        So by setting different cells with different labels, which will be launched on different processes,
+        users can launch a distributed training or predicting job.
+
+        Note:
+            - This method is effective only after
+              `mindspore.communication.init()` is called for dynamic cluster building.
 
         Args:
-
+            role (str): The role of the process on which this cell will be launched.
+                Only 'MS_WORKER' is supported for now.
+            rank_id (int): The rank id of the process on which this cell will be launched.
+                The rank is unique in processes with the same role.
+
+        Examples:
+            >>> from mindspore import context
+            >>> import mindspore.nn as nn
+            >>> context.set_context(mode=context.GRAPH_MODE)
+            >>> fc = nn.Dense(2, 3)
+            >>> fc.place('MS_WORKER', 0)
+        """
+        all_ops = self._get_prims_recursively()
+        for op in all_ops:
+            op.place(role, rank_id)
+
+    def _check_dynamic_tensor(self, set_input, net_input, index):
         """
-
-
-
-
-
-
+        Check if tensor is correctly set for dynamic shape.
+
+        Args:
+            set_input (Tensor): Tensor set for dynamic shape.
+            net_input (Tensor): Input tensor of the Cell object.
+            index (int): Tensor index for set inputs.
+        """
+        if not isinstance(net_input, Tensor):
+            raise TypeError(
+                f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
+        if set_input.dtype != net_input.dtype:
+            raise ValueError(
+                f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
+        if net_input.dim() != 0 and set_input.dim() != net_input.dim():
+            raise ValueError(
+                f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
+        if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+            raise ValueError(
+                f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
+
+    def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
+        """
+        Check if graph has been compiled with dynamic shape.
+
+        Args:
+            net_inputs (tuple): Inputs of the Cell object.
+        """
+        set_inputs_len = len(set_inputs)
+        net_inputs_len = len(net_inputs)
+        if set_inputs_len != net_inputs_len:
+            raise ValueError("The length of 'set_inputs' must be equal to network's inputs, "
+                             f"but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
+        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
             if isinstance(set_input, Tensor):
-
+                self._check_dynamic_tensor(set_input, net_input, index)
+            elif isinstance(set_input, (tuple, list)):
+                if not isinstance(net_input, (tuple, list)):
                     raise TypeError(
-                        f"The {index + 1}th input type of 'set_inputs' must be
-
-
-
-
-        if net_input.dim() != 0 and set_input.dim() != net_input.dim():
-            raise ValueError(
-                f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
-                f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
-        if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+                        f"The {index + 1}th input type of 'set_inputs' must be tuple or list, "
+                        f"but got {type(net_input)}.")
+                self._check_compile_dynamic_shape(set_input, net_input)
+            else:
+                if net_input != set_input:
                     raise ValueError(
-                        f"The {index + 1}th input
-                        f"
+                        f"The {index + 1}th input of 'set_inputs' must be the same with network's input, but got "
+                        f"set_inputs: {set_input} and network's input: {net_input}.")
 
 
 class GraphCell(Cell):
|
|
|
2228
2272
|
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
|
|
2229
2273
|
If the parameter exists in the graph according to the name, update it's value.
|
|
2230
2274
|
If the parameter does not exist, ignore it. Default: None.
|
|
2275
|
+
obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
|
|
2276
|
+
used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
|
|
2277
|
+
a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
|
|
2278
|
+
provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: None.
|
|
2279
|
+
|
|
2231
2280
|
Raises:
|
|
2232
2281
|
TypeError: If the `graph` is not a FuncGraph.
|
|
2233
2282
|
TypeError: If the `params_init` is not a dict.
|
|
@@ -2242,7 +2291,8 @@ class GraphCell(Cell):
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
         >>> from mindspore import Tensor
-        >>>
+        >>> from mindspore import context
+        >>> context.set_context(mode=context.GRAPH_MODE)
         >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
         >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> ms.export(net, input, file_name="net", file_format="MINDIR")
@@ -2254,13 +2304,23 @@ class GraphCell(Cell):
          [6. 9. 6.]
          [4. 6. 4.]]]]
     """
-
+
+    def __init__(self, graph, params_init=None, obf_random_seed=None):
         super(GraphCell, self).__init__(auto_prefix=True)
         if not isinstance(graph, FuncGraph):
             raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
                             f"but got type {type(graph)}.")
         self.graph = graph
-
+        self.obf_random_seed = obf_random_seed
+        if obf_random_seed is not None:
+            if not isinstance(obf_random_seed, int):
+                raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
+            int_64_max = 9223372036854775807
+            if obf_random_seed <= 0 or obf_random_seed > int_64_max:
+                raise ValueError(
+                    "'obf_random_seed' must be larger than 0, and less or equal than int64 ({}),"
+                    "but got {}.".format(int_64_max, obf_random_seed))
+            self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
         params_init = {} if params_init is None else params_init
         if not isinstance(params_init, dict):
             raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2277,10 +2337,13 @@ class GraphCell(Cell):
     def construct(self, *inputs):
         return self.graph(*inputs)
 
-    def __call__(self, *
+    def __call__(self, *args, **kwargs):
         self.phase = "graph_load_from_mindir"
         self._add_attr("graph_load_from_mindir", self.graph)
-
+        if not self.obf_random_seed:
+            return self.compile_and_run(*args, **kwargs)
+        append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
+        return self.compile_and_run(*args, append_input, **kwargs)
 
 
 def _check_param_list_tuple(value):
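Tying the `obf_random_seed` plumbing together, a hedged sketch (the file name and seed are illustrative, and the MindIR must have been obfuscated with the same seed at export time):

    import mindspore as ms
    import mindspore.nn as nn

    graph = ms.load("obf_net.mindir")                  # hypothetical obfuscated file
    net = nn.GraphCell(graph, obf_random_seed=3423)    # seed used during export
    # __call__ now appends the derived branch-control tensor automatically.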