mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -31,15 +31,40 @@ from mindspore.common.hook_handle import HookHandle
|
|
|
31
31
|
from mindspore.context import ParallelMode
|
|
32
32
|
from mindspore import context
|
|
33
33
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
|
34
|
-
from mindspore
|
|
34
|
+
from mindspore import _checkparam as Validator
|
|
35
35
|
from mindspore.common import dtype as mstype
|
|
36
36
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
37
|
+
from mindspore.common.api import _generate_branch_control_input
|
|
37
38
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
38
39
|
from mindspore.common.tensor import Tensor
|
|
39
40
|
from mindspore.ops.operations import Cast
|
|
40
41
|
from mindspore.ops.primitive import Primitive
|
|
41
42
|
from mindspore.ops.operations import _inner_ops as inner
|
|
42
43
|
from mindspore.parallel.shard import Shard
|
|
44
|
+
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _check_args(args):
|
|
48
|
+
"""Check the input args's type"""
|
|
49
|
+
index = 1
|
|
50
|
+
for item in args:
|
|
51
|
+
if isinstance(item, Tensor) and item.has_init:
|
|
52
|
+
item.init_data()
|
|
53
|
+
elif isinstance(item, numpy.ndarray):
|
|
54
|
+
suffix = "th"
|
|
55
|
+
if index == 1:
|
|
56
|
+
suffix = "st"
|
|
57
|
+
elif index == 2:
|
|
58
|
+
suffix = "nd"
|
|
59
|
+
elif index == 3:
|
|
60
|
+
suffix = "rd"
|
|
61
|
+
|
|
62
|
+
input_index = str(index) + suffix
|
|
63
|
+
raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
|
|
64
|
+
f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
|
|
65
|
+
f"), and tuple or list containing only these types, and dict whose values are these "
|
|
66
|
+
f"types, but the {input_index} arg type is {type(item)}.")
|
|
67
|
+
index += 1
|
|
43
68
|
|
|
44
69
|
|
|
45
70
|
class Cell(Cell_):
|
|
@@ -53,9 +78,12 @@ class Cell(Cell_):
|
|
|
53
78
|
PYNATIVE_MODE (dynamic graph mode).
|
|
54
79
|
|
|
55
80
|
Args:
|
|
56
|
-
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its
|
|
57
|
-
|
|
58
|
-
|
|
81
|
+
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
|
|
82
|
+
affects the names of parameters in the `Cell`. If set to True, the parameter name will be
|
|
83
|
+
automatically prefixed, otherwise not. In general, the backbone network should be set to True,
|
|
84
|
+
otherwise the duplicate name problem will appear. The cell to train the backbone network, such as
|
|
85
|
+
optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to False, otherwise the
|
|
86
|
+
parameter name in backbone will be changed by mistake. Default: True.
|
|
59
87
|
flags (dict, optional): Network configuration information, currently it is used for the binding of network
|
|
60
88
|
and dataset. Users can also customize network attributes by this parameter. Default: None.
|
|
61
89
|
|
|
@@ -139,6 +167,7 @@ class Cell(Cell_):
|
|
|
139
167
|
self.saved_dynamic_shape = None
|
|
140
168
|
self._jit_config_dict = dict()
|
|
141
169
|
self.grad_ops_label = False
|
|
170
|
+
self.to_float_fp16 = False
|
|
142
171
|
|
|
143
172
|
def __getstate__(self):
|
|
144
173
|
base = Cell_.__getstate__(self)
|
|
@@ -150,6 +179,9 @@ class Cell(Cell_):
|
|
|
150
179
|
self.__dict__ = dict_
|
|
151
180
|
self._attr_synced = False
|
|
152
181
|
|
|
182
|
+
def __bool__(self):
|
|
183
|
+
return True
|
|
184
|
+
|
|
153
185
|
@property
|
|
154
186
|
def _cell_tag(self):
|
|
155
187
|
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
|
@@ -325,10 +357,10 @@ class Cell(Cell_):
|
|
|
325
357
|
cells_compile_cache.pop(id(self), None)
|
|
326
358
|
try:
|
|
327
359
|
if self.compile_cache:
|
|
328
|
-
_cell_graph_executor.del_net_res(self.compile_cache)
|
|
329
|
-
except AttributeError:
|
|
360
|
+
_cell_graph_executor.del_net_res(self, self.compile_cache)
|
|
361
|
+
except AttributeError as e:
|
|
330
362
|
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
331
|
-
f"Please use 'super().__init__()'.")
|
|
363
|
+
f"Please use 'super().__init__()'.") from e
|
|
332
364
|
|
|
333
365
|
def __delattr__(self, name):
|
|
334
366
|
if name in self._params:
|
|
@@ -415,46 +447,46 @@ class Cell(Cell_):
|
|
|
415
447
|
output = self._run_forward_hook(cast_inputs, output)
|
|
416
448
|
return output
|
|
417
449
|
|
|
418
|
-
def _check_construct_args(self, *
|
|
450
|
+
def _check_construct_args(self, *args):
|
|
419
451
|
"""Check the args needed by the function construct"""
|
|
420
|
-
if kwargs:
|
|
421
|
-
raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
|
|
422
|
-
f"or there is a key in kwargs that is not used as a function argument. "
|
|
423
|
-
f"args: {inputs}, kwargs: {kwargs}")
|
|
424
452
|
positional_args = 0
|
|
425
453
|
default_args = 0
|
|
454
|
+
has_var = False
|
|
426
455
|
for value in inspect.signature(self.construct).parameters.values():
|
|
427
456
|
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
|
|
428
|
-
|
|
457
|
+
has_var = True
|
|
429
458
|
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
|
430
459
|
if value.default is inspect.Parameter.empty:
|
|
431
460
|
positional_args += 1
|
|
432
461
|
else:
|
|
433
462
|
default_args += 1
|
|
434
463
|
|
|
435
|
-
if
|
|
464
|
+
if has_var:
|
|
465
|
+
return
|
|
466
|
+
|
|
467
|
+
if len(args) < positional_args:
|
|
436
468
|
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
|
|
437
|
-
f"but got {len(
|
|
469
|
+
f"but got {len(args)}. When using set_inputs, please make sure that all networks "
|
|
438
470
|
f"and loss functions are configured with set_inputs.")
|
|
439
471
|
|
|
440
|
-
if len(
|
|
472
|
+
if len(args) > positional_args + default_args:
|
|
441
473
|
construct_inputs_names = self.construct.__code__.co_varnames
|
|
442
474
|
if 'self' not in construct_inputs_names:
|
|
443
475
|
raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
|
|
444
476
|
|
|
445
477
|
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
|
|
446
478
|
f"{default_args} default argument, total {positional_args + default_args}, "
|
|
447
|
-
f"but got {len(
|
|
479
|
+
f"but got {len(args)}.")
|
|
448
480
|
|
|
449
481
|
def _hook_fn_registered(self):
|
|
450
482
|
'''Hook function in graph mode'''
|
|
451
|
-
#Check super().__init__() in graph mode.
|
|
483
|
+
# Check super().__init__() in graph mode.
|
|
452
484
|
try:
|
|
453
485
|
if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
|
|
454
486
|
return True
|
|
455
|
-
except AttributeError:
|
|
487
|
+
except AttributeError as e:
|
|
456
488
|
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
457
|
-
f"Please use 'super().__init__()'.")
|
|
489
|
+
f"Please use 'super().__init__()'.") from e
|
|
458
490
|
if not self._is_recursion_hook:
|
|
459
491
|
self._is_recursion_hook = True
|
|
460
492
|
for cell in self.cells():
|
|
@@ -585,31 +617,9 @@ class Cell(Cell_):
|
|
|
585
617
|
|
|
586
618
|
return cast_inputs
|
|
587
619
|
|
|
588
|
-
def _check_args(self, args):
|
|
589
|
-
"""Check the input args's type"""
|
|
590
|
-
index = 1
|
|
591
|
-
for item in args:
|
|
592
|
-
if isinstance(item, Tensor) and item.has_init:
|
|
593
|
-
item.init_data()
|
|
594
|
-
elif isinstance(item, numpy.ndarray):
|
|
595
|
-
suffix = "th"
|
|
596
|
-
if index == 1:
|
|
597
|
-
suffix = "st"
|
|
598
|
-
elif index == 2:
|
|
599
|
-
suffix = "nd"
|
|
600
|
-
elif index == 3:
|
|
601
|
-
suffix = "rd"
|
|
602
|
-
|
|
603
|
-
input_index = str(index) + suffix
|
|
604
|
-
raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
|
|
605
|
-
f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
|
|
606
|
-
f"), and tuple or list containing only these types, and dict whose values are these "
|
|
607
|
-
f"types, but the {input_index} arg type is {type(item)}.")
|
|
608
|
-
index += 1
|
|
609
|
-
|
|
610
620
|
def __call__(self, *args, **kwargs):
|
|
611
621
|
if self.__class__.construct is Cell.construct:
|
|
612
|
-
raise AttributeError("For 'Cell', the method 'construct' is not defined.
|
|
622
|
+
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
|
|
613
623
|
|
|
614
624
|
if kwargs:
|
|
615
625
|
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
|
@@ -619,28 +629,28 @@ class Cell(Cell_):
|
|
|
619
629
|
|
|
620
630
|
# Run in Graph mode.
|
|
621
631
|
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
|
|
622
|
-
self._check_construct_args(*args
|
|
632
|
+
self._check_construct_args(*args)
|
|
623
633
|
if self._hook_fn_registered():
|
|
624
634
|
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
|
625
635
|
f"function, please use context.set_context to set pynative mode.")
|
|
626
|
-
out = self.compile_and_run(*args)
|
|
636
|
+
out = self.compile_and_run(*args, **kwargs)
|
|
627
637
|
return out
|
|
628
638
|
|
|
629
639
|
# Run in PyNative mode.
|
|
630
640
|
if _pynative_executor.is_first_cell():
|
|
631
|
-
_pynative_executor.set_lazy_build(True)
|
|
632
641
|
_pynative_executor._optimizer = getattr(self, "optimizer", None)
|
|
633
642
|
_pynative_executor._top_cell = self
|
|
634
|
-
# There many Casts in parameter_broadcast. Enable
|
|
643
|
+
# There many Casts in parameter_broadcast. Enable build faster.
|
|
635
644
|
self._do_parameter_broadcast()
|
|
636
645
|
|
|
637
|
-
|
|
646
|
+
_check_args(args)
|
|
647
|
+
self._check_cell_flags_in_pynative()
|
|
638
648
|
|
|
639
649
|
if self.requires_grad:
|
|
640
650
|
_pynative_executor.set_grad_flag(True)
|
|
641
651
|
|
|
642
652
|
if self._dynamic_shape_inputs is not None:
|
|
643
|
-
self._check_compile_dynamic_shape(
|
|
653
|
+
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
644
654
|
|
|
645
655
|
try:
|
|
646
656
|
_pynative_executor.new_graph(self, *args, **kwargs)
|
|
@@ -654,6 +664,12 @@ class Cell(Cell_):
|
|
|
654
664
|
output = output.data
|
|
655
665
|
return output
|
|
656
666
|
|
|
667
|
+
def _check_cell_flags_in_pynative(self):
|
|
668
|
+
"""Check the flags added to cell in pynative mode"""
|
|
669
|
+
if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
|
|
670
|
+
raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
|
|
671
|
+
"'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
|
|
672
|
+
|
|
657
673
|
def _add_attr(self, name, value):
|
|
658
674
|
if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
|
|
659
675
|
super(Cell, self)._add_attr(name, value)
|
|
@@ -841,7 +857,6 @@ class Cell(Cell_):
|
|
|
841
857
|
"""
|
|
842
858
|
logger.warning("'load_parameter_slice' function is deprecated.")
|
|
843
859
|
|
|
844
|
-
|
|
845
860
|
def set_parallel_input_with_inputs(self, *inputs):
|
|
846
861
|
"""
|
|
847
862
|
Slice inputs tensors by parallel strategies.
|
|
@@ -860,8 +875,8 @@ class Cell(Cell_):
|
|
|
860
875
|
Args:
|
|
861
876
|
inputs (tuple): Inputs of the Cell object.
|
|
862
877
|
|
|
863
|
-
|
|
864
|
-
This is an experimental
|
|
878
|
+
.. warning::
|
|
879
|
+
This is an experimental API that is subject to change or deletion.
|
|
865
880
|
|
|
866
881
|
Examples:
|
|
867
882
|
>>> import numpy as np
|
|
@@ -892,7 +907,7 @@ class Cell(Cell_):
|
|
|
892
907
|
if self._dynamic_shape_inputs:
|
|
893
908
|
ds.config.set_dynamic_shape(True)
|
|
894
909
|
if context._get_mode() == context.PYNATIVE_MODE:
|
|
895
|
-
_pynative_executor.set_dynamic_input(self
|
|
910
|
+
_pynative_executor.set_dynamic_input(self)
|
|
896
911
|
|
|
897
912
|
def get_inputs(self):
|
|
898
913
|
"""
|
|
@@ -901,35 +916,31 @@ class Cell(Cell_):
|
|
|
901
916
|
Returns:
|
|
902
917
|
inputs (tuple), Inputs of the Cell object.
|
|
903
918
|
|
|
904
|
-
|
|
905
|
-
This is an experimental
|
|
919
|
+
.. warning::
|
|
920
|
+
This is an experimental API that is subject to change or deletion.
|
|
906
921
|
"""
|
|
907
922
|
|
|
908
923
|
return self._dynamic_shape_inputs
|
|
909
924
|
|
|
910
|
-
def compile(self, *
|
|
925
|
+
def compile(self, *args, **kwargs):
|
|
911
926
|
"""
|
|
912
927
|
Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
|
|
913
928
|
|
|
914
929
|
Args:
|
|
915
|
-
|
|
930
|
+
args (tuple): Args of the Cell object.
|
|
931
|
+
kwargs (dict): Kwargs of the Cell object.
|
|
916
932
|
"""
|
|
917
|
-
if self._dynamic_shape_inputs is None
|
|
918
|
-
_cell_graph_executor.compile(self,
|
|
919
|
-
jit_config_dict=self._jit_config_dict)
|
|
933
|
+
if self._dynamic_shape_inputs is None:
|
|
934
|
+
_cell_graph_executor.compile(self, phase=self.phase,
|
|
935
|
+
jit_config_dict=self._jit_config_dict, *args, **kwargs)
|
|
920
936
|
else:
|
|
921
|
-
self._check_compile_dynamic_shape(
|
|
922
|
-
if self.saved_dynamic_shape:
|
|
923
|
-
for i in range(len(self.saved_dynamic_shape)):
|
|
924
|
-
if self.saved_dynamic_shape[i].shape != self._dynamic_shape_inputs[i].shape:
|
|
925
|
-
return
|
|
926
|
-
|
|
937
|
+
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
927
938
|
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
|
928
939
|
_cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
|
|
929
|
-
jit_config_dict=self._jit_config_dict)
|
|
940
|
+
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
930
941
|
logger.debug("Compiled Graph with dynamic shape")
|
|
931
942
|
|
|
932
|
-
def compile_and_run(self, *
|
|
943
|
+
def compile_and_run(self, *args, **kwargs):
|
|
933
944
|
"""
|
|
934
945
|
Compile and run Cell, the input must be consistent with the input defined in construct.
|
|
935
946
|
|
|
@@ -937,15 +948,16 @@ class Cell(Cell_):
|
|
|
937
948
|
It is not recommended to call directly.
|
|
938
949
|
|
|
939
950
|
Args:
|
|
940
|
-
|
|
951
|
+
args (tuple): Args of the Cell object.
|
|
952
|
+
kwargs (dict): Kwargs of the Cell object.
|
|
941
953
|
|
|
942
954
|
Returns:
|
|
943
955
|
Object, the result of executing.
|
|
944
956
|
"""
|
|
945
|
-
self.compile(*
|
|
957
|
+
self.compile(*args, **kwargs)
|
|
946
958
|
|
|
947
|
-
|
|
948
|
-
return _cell_graph_executor(self, *
|
|
959
|
+
new_args = _get_args_for_run(self, args, kwargs)
|
|
960
|
+
return _cell_graph_executor(self, *new_args, phase=self.phase)
|
|
949
961
|
|
|
950
962
|
def auto_parallel_compile_and_run(self):
|
|
951
963
|
"""
|
|
@@ -1027,8 +1039,12 @@ class Cell(Cell_):
|
|
|
1027
1039
|
|
|
1028
1040
|
Raises:
|
|
1029
1041
|
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
|
1042
|
+
TypeError: If type of `child_name` is not str.
|
|
1030
1043
|
TypeError: Child Cell's type is incorrect.
|
|
1031
1044
|
"""
|
|
1045
|
+
if not isinstance(child_name, str):
|
|
1046
|
+
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
1047
|
+
f"but got {type(child_name)}.")
|
|
1032
1048
|
if not child_name or '.' in child_name:
|
|
1033
1049
|
raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
|
|
1034
1050
|
"can not contain '.'")
|
|
@@ -1040,7 +1056,7 @@ class Cell(Cell_):
|
|
|
1040
1056
|
f"but got type {type(child_cell)}.")
|
|
1041
1057
|
self._cells[child_name] = child_cell
|
|
1042
1058
|
|
|
1043
|
-
def construct(self, *
|
|
1059
|
+
def construct(self, *args, **kwargs):
|
|
1044
1060
|
"""
|
|
1045
1061
|
Defines the computation to be performed. This method must be overridden by all subclasses.
|
|
1046
1062
|
|
|
@@ -1048,7 +1064,7 @@ class Cell(Cell_):
|
|
|
1048
1064
|
It is not supported currently that inputs contain both tuple and non-tuple types at same time.
|
|
1049
1065
|
|
|
1050
1066
|
Args:
|
|
1051
|
-
|
|
1067
|
+
args (tuple): Tuple of variable parameters.
|
|
1052
1068
|
kwargs (dict): Dictionary of variable keyword parameters.
|
|
1053
1069
|
|
|
1054
1070
|
Returns:
|
|
@@ -1200,6 +1216,7 @@ class Cell(Cell_):
|
|
|
1200
1216
|
param.is_init = False
|
|
1201
1217
|
param.name = prefix + name
|
|
1202
1218
|
|
|
1219
|
+
@jit_forbidden_register
|
|
1203
1220
|
def trainable_params(self, recurse=True):
|
|
1204
1221
|
"""
|
|
1205
1222
|
Returns all trainable parameters.
|
|
@@ -1214,6 +1231,7 @@ class Cell(Cell_):
|
|
|
1214
1231
|
"""
|
|
1215
1232
|
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
|
|
1216
1233
|
|
|
1234
|
+
@jit_forbidden_register
|
|
1217
1235
|
def untrainable_params(self, recurse=True):
|
|
1218
1236
|
"""
|
|
1219
1237
|
Returns all untrainable parameters.
|
|
@@ -1228,6 +1246,7 @@ class Cell(Cell_):
|
|
|
1228
1246
|
"""
|
|
1229
1247
|
return list(filter(lambda x: not x.requires_grad, self.get_parameters(expand=recurse)))
|
|
1230
1248
|
|
|
1249
|
+
@jit_forbidden_register
|
|
1231
1250
|
def get_parameters(self, expand=True):
|
|
1232
1251
|
"""
|
|
1233
1252
|
Returns an iterator over cell parameters.
|
|
@@ -1419,6 +1438,38 @@ class Cell(Cell_):
|
|
|
1419
1438
|
if "fp32" in flags and flags.get("fp32", False):
|
|
1420
1439
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1421
1440
|
|
|
1441
|
+
def apply(self, fn):
|
|
1442
|
+
"""
|
|
1443
|
+
Applies fn recursively to every subcell (as returned by .cells()) as well as self.
|
|
1444
|
+
Typical use includes initializing the parameters of a model.
|
|
1445
|
+
|
|
1446
|
+
Args:
|
|
1447
|
+
fn (function): function to be applied to each subcell.
|
|
1448
|
+
|
|
1449
|
+
Returns:
|
|
1450
|
+
Cell, self.
|
|
1451
|
+
|
|
1452
|
+
Examples:
|
|
1453
|
+
>>> import mindspore.nn as nn
|
|
1454
|
+
>>> from mindspore.common.initializer import initializer, One
|
|
1455
|
+
>>> net = nn.SequentialCell(nn.Dense(2, 2), nn.Dense(2, 2))
|
|
1456
|
+
>>> def func(cell):
|
|
1457
|
+
... if isinstance(cell, nn.Dense):
|
|
1458
|
+
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
|
1459
|
+
>>> net.apply(func)
|
|
1460
|
+
SequentialCell<
|
|
1461
|
+
(0): Dense<input_channels=2, output_channels=2, has_bias=True>
|
|
1462
|
+
(1): Dense<input_channels=2, output_channels=2, has_bias=True>
|
|
1463
|
+
>
|
|
1464
|
+
>>> print(net[0].weight.asnumpy())
|
|
1465
|
+
[[1. 1.]
|
|
1466
|
+
[1. 1.]]
|
|
1467
|
+
"""
|
|
1468
|
+
for cell in self.cells():
|
|
1469
|
+
cell.apply(fn)
|
|
1470
|
+
fn(self)
|
|
1471
|
+
return self
|
|
1472
|
+
|
|
1422
1473
|
def add_flags(self, **flags):
|
|
1423
1474
|
"""
|
|
1424
1475
|
Add customized attributes for cell.
|
|
@@ -1473,7 +1524,7 @@ class Cell(Cell_):
|
|
|
1473
1524
|
Add cast on all inputs of cell and child cells to run with certain float type.
|
|
1474
1525
|
|
|
1475
1526
|
If `dst_type` is `mindspore.dtype.float16`, all the inputs of Cell, including input, Parameter and Tensor, will
|
|
1476
|
-
be cast to float16. Please refer to the usage in source code of :func:`mindspore.build_train_network`.
|
|
1527
|
+
be cast to float16. Please refer to the usage in source code of :func:`mindspore.amp.build_train_network`.
|
|
1477
1528
|
|
|
1478
1529
|
Note:
|
|
1479
1530
|
Multiple calls will overwrite.
|
|
@@ -1505,8 +1556,10 @@ class Cell(Cell_):
|
|
|
1505
1556
|
"but got {}.".format(dst_type))
|
|
1506
1557
|
if dst_type == mstype.float16:
|
|
1507
1558
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
|
1559
|
+
self.to_float_fp16 = True
|
|
1508
1560
|
else:
|
|
1509
1561
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1562
|
+
self.to_float_fp16 = False
|
|
1510
1563
|
flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
|
|
1511
1564
|
self._add_init_args(**flags)
|
|
1512
1565
|
return self
|
|
@@ -1517,7 +1570,7 @@ class Cell(Cell_):
|
|
|
1517
1570
|
accelerate the algorithm in the algorithm library.
|
|
1518
1571
|
|
|
1519
1572
|
If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
|
|
1520
|
-
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.0
|
|
1573
|
+
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/python/mindspore/boost>`_.
|
|
1521
1574
|
|
|
1522
1575
|
Note:
|
|
1523
1576
|
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
|
@@ -1562,6 +1615,10 @@ class Cell(Cell_):
|
|
|
1562
1615
|
for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
|
|
1563
1616
|
set to true, the training branch will be executed, otherwise another branch.
|
|
1564
1617
|
|
|
1618
|
+
Note:
|
|
1619
|
+
When execute function Model.train(), framework will call Cell.set_train(True).
|
|
1620
|
+
When execute function Model.eval(), framework will call Cell.set_train(False).
|
|
1621
|
+
|
|
1565
1622
|
Args:
|
|
1566
1623
|
mode (bool): Specifies whether the model is training. Default: True.
|
|
1567
1624
|
|
|
@@ -1605,6 +1662,11 @@ class Cell(Cell_):
             logger.warning("For Cell, jit config can only be set once, ignore this setting.")
         else:
             self._jit_config_dict = jit_config.jit_config_dict
+            enable_ge = os.getenv("MS_ENABLE_GE") == '1'
+            enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
+            if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
+                raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jit_level={}".
+                                   format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
 
     def flatten_weights(self, fusion_size=0):
         """
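The added check makes `jit_level="O3"` and the GE backend an all-or-nothing pair. A sketch of a combination that passes it; the environment value and level are illustrative:

    import os
    from mindspore import nn, JitConfig

    os.environ["MS_ENABLE_GE"] = "1"               # GE backend enabled ...
    net = nn.Dense(3, 4)
    net.set_jit_config(JitConfig(jit_level="O3"))  # ... so O3 is required here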
@@ -2017,9 +2079,6 @@ class Cell(Cell_):
         """
         Set the cell recomputed.
         """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise TypeError("Recompute is not supported in pynative mode currently, you can use "
-                            "'context.set_context(mode=context.GRAPH_MODE)' to set graph mode.")
         Validator.check_bool(mode)
         Validator.check_bool(output_recompute)
         if not self._has_config_recompute:
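With this guard removed, `recompute` is no longer rejected under PyNative mode; usage itself is unchanged. A minimal sketch:

    from mindspore import nn

    block = nn.Dense(3, 4)
    # Drop this cell's activations in the forward pass and recompute them
    # during backward, trading compute for memory.
    block.recompute()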
@@ -2122,13 +2181,13 @@ class Cell(Cell_):
         """
         Set the label for all operators in this cell.
         This label tells MindSpore compiler on which process this cell should be launched.
-        And each process's identical label consists of input
+        And each process's identical label consists of input `role` and `rank_id`.
         So by setting different cells with different labels, which will be launched on different processes,
         users can launch a distributed training or predicting job.
 
         Note:
            - This method is effective only after
-
+             `mindspore.communication.init()` is called for dynamic cluster building.
 
         Args:
             role (str): The role of the process on which this cell will be launched.
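A hedged sketch of the documented `place` flow; the `MS_WORKER` role string follows the dynamic-cluster convention, and `init()` must run first as the Note says:

    from mindspore import nn
    from mindspore.communication import init

    init()                # dynamic cluster building
    cell = nn.Dense(3, 4)
    cell.place(role="MS_WORKER", rank_id=0)  # pin this cell's operators to worker 0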
@@ -2147,35 +2206,57 @@ class Cell(Cell_):
         for op in all_ops:
             op.place(role, rank_id)
 
-    def
+    def _check_dynamic_tensor(self, set_input, net_input, index):
         """
-        Check if
+        Check if tensor is correctly set for dynamic shape.
 
         Args:
-
+            set_input (Tensor): Tensor set for dynamic shape.
+            net_input (Tensor): Input tensor of the Cell object.
+            index (int): Tensor index for set inputs.
+        """
+        if not isinstance(net_input, Tensor):
+            raise TypeError(
+                f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
+        if set_input.dtype != net_input.dtype:
+            raise ValueError(
+                f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
+        if net_input.dim() != 0 and set_input.dim() != net_input.dim():
+            raise ValueError(
+                f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
+        if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+            raise ValueError(
+                f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
+                f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
+
+    def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
         """
-
-
-
-
-
-
+        Check if graph has been compiled with dynamic shape.
+
+        Args:
+            net_inputs (tuple): Inputs of the Cell object.
+        """
+        set_inputs_len = len(set_inputs)
+        net_inputs_len = len(net_inputs)
+        if set_inputs_len != net_inputs_len:
+            raise ValueError("The length of 'set_inputs' must be equal to network's inputs, "
+                             f"but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
+        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
             if isinstance(set_input, Tensor):
-
+                self._check_dynamic_tensor(set_input, net_input, index)
+            elif isinstance(set_input, (tuple, list)):
+                if not isinstance(net_input, (tuple, list)):
                     raise TypeError(
-                        f"The {index + 1}th input type of 'set_inputs' must be
-
-
-
-
-                if net_input.dim() != 0 and set_input.dim() != net_input.dim():
-                    raise ValueError(
-                        f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
-                        f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
-                if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+                        f"The {index + 1}th input type of 'set_inputs' must be tuple or list, "
+                        f"but got {type(net_input)}.")
+                self._check_compile_dynamic_shape(set_input, net_input)
+            else:
+                if net_input != set_input:
                     raise ValueError(
-                        f"The {index + 1}th input
-                        f"
+                        f"The {index + 1}th input of 'set_inputs' must be the same with network's input, but got "
+                        f"set_inputs: {set_input} and network's input: {net_input}.")
 
 
 class GraphCell(Cell):
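These two helpers back `Cell.set_inputs`: each placeholder must match the real input's dtype and rank, and a `None` dimension participates as `-1` in the shape comparison above. A minimal sketch:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    net = nn.ReLU()
    # Declare a dynamic first dimension; dtype and rank must match real inputs.
    net.set_inputs(Tensor(shape=[None, 3], dtype=ms.float32))
    out = net(Tensor(np.ones((5, 3), np.float32)))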
@@ -2191,10 +2272,10 @@ class GraphCell(Cell):
             The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
             If the parameter exists in the graph according to the name, update its value.
             If the parameter does not exist, ignore it. Default: None.
-
-            protection, which can refer to :func:`mindspore.obfuscate_model`. If the input
-            func_graph loaded from a mindir file obfuscated
-
+        obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "Dynamic obfuscation" is
+            used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
+            a func_graph loaded from a mindir file obfuscated with `obf_random_seed`, then `obf_random_seed` should be
+            provided. `obf_random_seed` should be in (0, 9223372036854775807]. Default: None.
 
         Raises:
             TypeError: If the `graph` is not a FuncGraph.
@@ -2210,7 +2291,8 @@ class GraphCell(Cell):
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
         >>> from mindspore import Tensor
-        >>>
+        >>> from mindspore import context
+        >>> context.set_context(mode=context.GRAPH_MODE)
         >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
         >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> ms.export(net, input, file_name="net", file_format="MINDIR")
@@ -2223,21 +2305,22 @@ class GraphCell(Cell):
         [4. 6. 4.]]]]
     """
 
-    def __init__(self, graph, params_init=None,
+    def __init__(self, graph, params_init=None, obf_random_seed=None):
         super(GraphCell, self).__init__(auto_prefix=True)
         if not isinstance(graph, FuncGraph):
             raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
                             f"but got type {type(graph)}.")
         self.graph = graph
-        self.
-        if
-            if not isinstance(
-                raise TypeError("'
+        self.obf_random_seed = obf_random_seed
+        if obf_random_seed is not None:
+            if not isinstance(obf_random_seed, int):
+                raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
             int_64_max = 9223372036854775807
-            if
+            if obf_random_seed <= 0 or obf_random_seed > int_64_max:
                 raise ValueError(
-                    "'
-                    "but got {}.".format(int_64_max,
+                    "'obf_random_seed' must be larger than 0 and less than or equal to int64 max ({}), "
+                    "but got {}.".format(int_64_max, obf_random_seed))
+            self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
         params_init = {} if params_init is None else params_init
         if not isinstance(params_init, dict):
             raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2254,24 +2337,13 @@ class GraphCell(Cell):
     def construct(self, *inputs):
         return self.graph(*inputs)
 
-    def __call__(self, *
+    def __call__(self, *args, **kwargs):
         self.phase = "graph_load_from_mindir"
         self._add_attr("graph_load_from_mindir", self.graph)
-        if not self.
-            return self.compile_and_run(*
-
-        return self.compile_and_run(*
-
-
-    def _obf_appended_inputs(obf_password):
-        seed_max = 2 ** 32 - 1
-        int_max = 2 ** 31 - 1
-        numpy.random.seed(obf_password % seed_max)
-        append_password = numpy.random.randint(int_max)
-        obf_password %= int_max
-        append_input_1 = Tensor((numpy.ones((1, 1)) * obf_password).astype(numpy.int32))
-        append_input_2 = Tensor((numpy.ones((1, 1)) * append_password).astype(numpy.int32))
-        return append_input_1, append_input_2
+        if not self.obf_random_seed:
+            return self.compile_and_run(*args, **kwargs)
+        append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
+        return self.compile_and_run(*args, append_input, **kwargs)
 
 
 def _check_param_list_tuple(value):
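Taken together, loading an obfuscated MindIR now passes the seed through the constructor, and `__call__` appends the derived branch-control input automatically. A hedged sketch; the file name and seed are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    graph = ms.load("obf_net.mindir")                # MindIR obfuscated with the same seed
    net = nn.GraphCell(graph, obf_random_seed=3423)
    out = net(Tensor(np.ones((1, 1, 3, 3), np.float32)))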
mindspore/nn/dynamic_lr.py CHANGED