mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import time
 import ast
 import inspect
 import importlib
+import hashlib
 from collections import OrderedDict
 from functools import wraps
 import numpy as np
@@ -35,15 +36,17 @@ from mindspore.common.tensor import Tensor as PythonTensor
 from mindspore.common.sparse_tensor import CSRTensor as PythonCSRTensor
 from mindspore.common.sparse_tensor import COOTensor as PythonCOOTensor
 from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
-from mindspore._c_expression import GraphExecutor_, Tensor,
+from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _ms_memory_recycle
+    _ms_memory_recycle, _bind_device_ctx
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
     _get_pipeline_stages, _is_in_auto_parallel_mode
-from mindspore
+from mindspore import _checkparam as Validator
+from mindspore._checkparam import is_stub_tensor
 from mindspore.common._utils import is_shape_unknown
 from mindspore.common.mutable import mutable
+from mindspore.common._register_for_adapter import ms_adapter_registry
 
 # store ms_function class compiled pipeline cache
 ms_compile_cache = set()
@@ -64,6 +67,8 @@ def _convert_python_data(data):
     Returns:
         data, a data convert C++ to python
     """
+    if isinstance(data, Tensor) and data.adapter_flag:
+        return ms_adapter_registry.tensor(data)
     if isinstance(data, Tensor) and not isinstance(data, PythonTensor):
         return PythonTensor(data, internal=True)
     if isinstance(data, CSRTensor) and not isinstance(data, PythonCSRTensor):
@@ -102,7 +107,8 @@ def _wrap_func(fn):
 
 def _check_all_tensor(sequence):
     for element in sequence:
-        if not isinstance(element, Tensor) and not (
+        if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple)
+                                                                                    and _check_all_tensor(element)):
             return False
     return True
 
@@ -116,28 +122,28 @@ def _handle_func_args(func, *args, **kwargs):
     bound_arguments.apply_defaults()
     args = bound_arguments.args
     kwargs = bound_arguments.kwargs
-    # After apply_defaults, kwargs should be empty here.
-    if kwargs:
-        raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
-                         f"or there is a key in kwargs that is not used as a function argument, "
-                         f"args: {args}, kwargs: {kwargs}")
 
     positional_args = 0
     default_args = 0
+    has_var = False
     for value in inspect.signature(func).parameters.values():
         if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
-
+            has_var = True
         if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
             if value.default is inspect.Parameter.empty:
                 positional_args += 1
             else:
                 default_args += 1
+
+    if has_var:
+        return args, kwargs
+
     if len(args) < positional_args:
         raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
     if len(args) > positional_args + default_args:
         raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
                         f"default argument, total {positional_args + default_args}, but got {len(args)}.")
-    return args
+    return args, kwargs
 
 
 sys_path = list(sys.path)
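
`_handle_func_args` previously rejected any leftover kwargs; it now returns `(args, kwargs)` and skips the arity checks entirely when the signature declares `*args` or `**kwargs`. A minimal sketch of a call pattern this unblocks (illustrative only; whether a given construct compiles in graph mode is a separate question):

```python
import numpy as np
import mindspore as ms

@ms.jit
def add_all(*inputs):
    # Variadic parameters no longer trip the jit argument handler.
    out = inputs[0]
    for x in inputs[1:]:
        out = out + x
    return out

a = ms.Tensor(np.ones((2, 2), np.float32))
print(add_all(a, a, a))
```
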
@@ -163,7 +169,8 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
     for node in ast.iter_child_nodes(root):
         module_name = ""
         if isinstance(node, ast.ImportFrom):
-
+            if node.module is not None:
+                module_name = node.module
             if node.level == 1:
                 module_name = "." + module_name
         elif not isinstance(node, ast.Import):
@@ -235,25 +242,42 @@ def _get_parameter_layout():
     return layout
 
 
-def
-    """
+def _handle_arg(obj, arg):
+    """Handle arg for runtime .If need handle the arg, return True"""
+    if isinstance(arg, PythonTensor):
+        if arg.has_init:
+            arg.init_data()
+        if not arg.const_arg:
+            return arg
+    elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
+        return arg
+    elif hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"):
+        # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
+        if isinstance(arg, list) and not arg:
+            return None
+        return arg
+    elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
+        return arg
+    elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
+            _check_all_tensor(arg):
+        return arg
+    return None
+
+
+def _get_args_for_run(obj, args, kwargs):
+    """Get the actual input args and kwargs for runtime."""
+    new_args = []
+    for arg in args:
+        new_arg = _handle_arg(obj, arg)
+        if new_arg is not None:
+            new_args.append(new_arg)
+
+    for _, value in kwargs.items():
+        new_value = _handle_arg(obj, value)
+        if new_value is not None:
+            new_args.append(new_value)
+
+    return new_args
 
 
 class _MindsporeFunctionExecutor:
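
The rewritten runtime-argument filter decides, per argument, whether a value is fed to the compiled graph at run time or folded away as a compile-time constant. A hedged illustration of the observable behavior (assumes the public `const_arg` tensor flag and `mindspore.mutable`, both part of the 2.0 API):

```python
import numpy as np
import mindspore as ms

x = ms.Tensor(np.ones((2, 2), np.float32))                  # forwarded to the runtime arg list
c = ms.Tensor(np.ones((2, 2), np.float32), const_arg=True)  # treated as a constant, filtered out
m = ms.mutable([1, 2, 3])                                   # kept: carries __ms_mutable__

# Per _handle_arg, mutable([]) is dropped: the empty list is eliminated by
# FuncGraphSpecializer and is not supported by the backend.
```
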
@@ -291,24 +315,42 @@ class _MindsporeFunctionExecutor:
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
 
     @_wrap_func
-    def __call__(self, *args):
+    def __call__(self, *args, **kwargs):
         args_list = args
         if self.obj is not None:
             args_list = args_list[1:]
-
-
-
+        phase = ""
+        try:
+            if context.get_context("mode") == context.PYNATIVE_MODE:
+                _pynative_executor.set_ms_function_compile_status(True, phase)
+                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+                _pynative_executor.set_ms_function_compile_status(False, phase)
+            else:
+                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+        except Exception as err:
+            _pynative_executor.clear_res()
+            raise err
+
         if context.get_context("precompile_only"):
             return None
-
+
+        new_inputs = self._generate_run_args(args_list, kwargs)
         output = self._graph_executor(tuple(new_inputs), phase)
         if context.get_context("mode") == context.PYNATIVE_MODE:
-            _pynative_executor.set_graph_phase(phase)
             output = _pynative_executor.grad_ms_function(output, *new_inputs)
 
+        enable_ge = os.getenv("MS_ENABLE_GE") == "1"
+        if enable_ge and self.jit_config_dict is None:
+            raise RuntimeError("GE and jit_level=O3 should be used together, but jit_config is None.")
+        if self.jit_config_dict:
+            enable_jit_level_o3 = self.jit_config_dict.get('jit_level') == "O3"
+            if (enable_ge and not enable_jit_level_o3) or (not enable_ge and enable_jit_level_o3):
+                raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jit_level={}".
+                                   format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
+
         return output
 
-    def compile(self,
+    def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
         # Check whether hook function registered on Cell object.
         if self.obj and hasattr(self.obj, "_hook_fn_registered"):
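
The new consistency check makes the GE backend and `jit_level` "O3" an all-or-nothing pair. A hedged sketch of the only accepted pairing (assumes the `MS_ENABLE_GE` environment variable read in the diff and the public `mindspore.JitConfig`):

```python
import os
import mindspore as ms

# Both must be set together, or _MindsporeFunctionExecutor.__call__ raises RuntimeError.
os.environ["MS_ENABLE_GE"] = "1"

@ms.jit(jit_config=ms.JitConfig(jit_level="O3"))
def square(x):
    return x * x
```
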
@@ -317,12 +359,12 @@ class _MindsporeFunctionExecutor:
                 f"If you want to use hook function, please use context.set_context to set "
                 f"pynative mode and remove 'jit' decorator.")
         # Chose dynamic shape tensors or actual input tensors as compile args.
-        compile_args = self._generate_compile_args(
+        compile_args = self._generate_compile_args(args)
         # Restore the mutable attr for every arg.
-        compile_args = _restore_mutable_attr(
+        compile_args = _restore_mutable_attr(args, compile_args)
 
         generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
-
+            str(self.fn.__code__.co_firstlineno)
         if _pynative_executor.grad_flag():
             generate_name = generate_name + ".grad"
         if _is_pynative_parallel():
@@ -348,7 +390,7 @@ class _MindsporeFunctionExecutor:
             self.enable_tuple_broaden = self.obj.enable_tuple_broaden
 
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
-        key = self._graph_executor.generate_arguments_key(compile_args, self.enable_tuple_broaden)
+        key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
         phase = generate_name + '.' + str(key)
         if phase in ms_compile_cache:
             return phase
@@ -359,11 +401,11 @@ class _MindsporeFunctionExecutor:
             self._graph_executor.set_jit_config(self.jit_config_dict)
 
         if self.obj is None:
-            is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
+            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
+            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
 
         if not is_compile:
             raise RuntimeError("Executor compile failed.")
@@ -393,47 +435,51 @@ class _MindsporeFunctionExecutor:
         # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
         compile_args = args_list
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
-        if isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
+        if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
             compile_args = self.obj.get_inputs()
-
-
-
+            if len(compile_args) != len(args_list):
+                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
+                                 f"dynamic shape tensors: {len(compile_args)}.")
+            for i, elem in enumerate(compile_args):
+                if isinstance(elem, PythonTensor):
+                    Validator.check_dynamic_shape(compile_args[i], args_list[i], i)
+
         # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
         if self.input_signature is not None:
             if not isinstance(self.input_signature, (tuple, list)):
                 self.input_signature = (self.input_signature,)
             self.input_signature = list(self.input_signature)
             dyn_shape = False
-            for
-
-
+            for i, elem in enumerate(self.input_signature):
+                if isinstance(elem, PythonTensor) and is_shape_unknown(elem.shape):
+                    Validator.check_dynamic_shape(self.input_signature[i], args_list[i], i)
                     dyn_shape = True
-            if
-                if not verify_inputs_signature(self.input_signature, args_list):
-                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
-            else:
+            if dyn_shape:
                 # Checkout whether the `sens` has been added to args_list.
                 if len(self.input_signature) == len(args_list) - 1:
-                    logger.warning(f"The number of actual input args
-                                   f"of input_signature args
-                                   f"be
+                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
+                                   f"of input_signature args '{len(self.input_signature)}'. The last actual args may "
+                                   f"be 'sens' and added it to compile args.")
                     self.input_signature.append(args_list[-1])
-                Validator.check_dynamic_shape(self.input_signature, args_list)
                 compile_args = tuple(self.input_signature)
-                _pynative_executor.set_dynamic_input(self.obj
+                _pynative_executor.set_dynamic_input(self.obj)
+            else:
+                if not verify_inputs_signature(self.input_signature, args_list):
+                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
         return compile_args
 
-    def _generate_run_args(self, args_list):
+    def _generate_run_args(self, args_list, kwargs):
         """
         Generate input args, which are required for running.
 
         Args:
             args_list (Tuple): Actual input args.
+            kwargs (Dict): Actual input kwargs.
 
         Returns:
             new_inputs, new input args, which are required for running.
         """
-        return _get_args_for_run(self, args_list)
+        return _get_args_for_run(self, args_list, kwargs)
 
 
 # The attributes used to identify a given object.
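
`_generate_compile_args` now validates each dynamic-shape placeholder against the actual input via `Validator.check_dynamic_shape`, element by element. A hedged usage sketch (dynamic axes written as `None` in a placeholder Tensor, the convention MindSpore 2.0 uses for `set_inputs` and `input_signature`):

```python
import numpy as np
import mindspore as ms

dyn = ms.Tensor(shape=[None, 64], dtype=ms.float32)  # first axis unknown until runtime

@ms.jit(input_signature=dyn)
def row_sum(x):
    return x.sum(axis=1)

out = row_sum(ms.Tensor(np.ones((8, 64), np.float32)))  # any batch size matches the signature
```
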
@@ -551,14 +597,15 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
         if os.getenv("MS_JIT") == '0':
             return func(*args, **kwargs)
 
-        args = _handle_func_args(func, *args, **kwargs)
+        args, kwargs = _handle_func_args(func, *args, **kwargs)
+
         process_obj = None
         if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
             process_obj = args[0]
         # only the function or cell instance wrapped by shard will fall into this branch
         if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
             process_obj = hash_args
-        out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
+        out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs)
         return out
 
     return staging_specialize
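
With kwargs threaded from `staging_specialize` through `_MindsporeFunctionExecutor.__call__` and `compile`, keyword arguments now reach both compilation and execution. A minimal sketch:

```python
import numpy as np
import mindspore as ms

@ms.jit
def affine(x, w, b):
    return x * w + b

x = ms.Tensor(np.ones((2, 2), np.float32))
# Keyword arguments are forwarded instead of rejected in 2.0.0rc1.
print(affine(x, w=ms.Tensor(2.0, ms.float32), b=ms.Tensor(1.0, ms.float32)))
```
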
@@ -648,7 +695,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
         ... closure_fn(inputs, func)
     """
 
-    logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
+    logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
                    "Please use 'mindspore.jit' instead.")
     return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
 
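
As the warning says, `ms_function` already forwards to `jit`, so migrating ahead of its removal is a rename:

```python
import mindspore as ms

@ms.ms_function          # 2.0.0a0 style: still works, but logs a deprecation warning
def f(x):
    return x + x

@ms.jit                  # 2.0.0rc1 replacement
def g(x):
    return x + x
```
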
@@ -719,7 +766,7 @@ def _add_flags(fn=None, **flags):
     return ret
 
 
-def
+def _no_recursive(callable_obj):
     """
     Method or function decorator for ignoring recursive check.
 
@@ -794,7 +841,7 @@ def ms_class(cls):
         20
     """
 
-    logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
+    logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
                    "Please use 'mindspore.jit_class' instead.")
 
     # Check if cls is of type class.
@@ -821,8 +868,8 @@ def jit_class(cls):
         Class.
 
     Raises:
-        TypeError: If jit_class is used for non-class types or nn.Cell.
-        AttributeError: If the private attributes or magic methods of the class decorated with jit_class is called.
+        TypeError: If `jit_class` is used for non-class types or nn.Cell.
+        AttributeError: If the private attributes or magic methods of the class decorated with `jit_class` is called.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
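
For reference, the decorator the corrected docstring describes: `jit_class` makes a plain Python class usable inside graph-mode `construct`, and raises the documented TypeError when applied to an `nn.Cell`. A hedged usage sketch:

```python
import mindspore as ms
from mindspore import nn

@ms.jit_class
class Scaler:
    def __init__(self, factor):
        self.factor = factor

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.scaler = Scaler(2)

    def construct(self, x):
        return x * self.scaler.factor
```
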
@@ -864,6 +911,25 @@ def jit_class(cls):
     return cls
 
 
+def set_adapter_config(config):
+    """
+    Register configuration information for MSAdapter.
+
+    Args:
+        config (dict): Configuration information.
+    """
+    if not isinstance(config, dict):
+        raise TypeError(f"The input argument of 'set_adapter_config' should be a dict, but got {config}.")
+    for key, value in config.items():
+        if key == "Tensor":
+            setattr(value, "__adapter_tensor__", True)
+            ms_adapter_registry.register_tensor(value)
+        elif key == "convert_object_map":
+            ms_adapter_registry.register_convert_map(value)
+        else:
+            raise ValueError(f"Unsupported key in adapter config: {key}")
+
+
 def _function_forbid_reuse(func):
     if not inspect.isfunction(func):
         raise TypeError(f'Decorator _function_forbid_reuse can only be used for function type, but got {func}.')
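
The new `set_adapter_config` accepts exactly two keys, "Tensor" and "convert_object_map". A hedged sketch of how an adapter layer such as MSAdapter might register itself (`AdapterTensor` and the empty map are hypothetical placeholders):

```python
import mindspore as ms
from mindspore.common.api import set_adapter_config

class AdapterTensor(ms.Tensor):  # hypothetical adapter tensor subclass
    pass

convert_object_map = {}          # hypothetical object-conversion table

set_adapter_config({"Tensor": AdapterTensor, "convert_object_map": convert_object_map})
```
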
@@ -937,7 +1003,6 @@ class _PyNativeExecutor:
|
|
|
937
1003
|
self._executor = PyNativeExecutor_.get_instance()
|
|
938
1004
|
self._executor.set_py_exe_path(sys.executable)
|
|
939
1005
|
self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
|
|
940
|
-
self._optimizer = None
|
|
941
1006
|
self._top_cell = None
|
|
942
1007
|
|
|
943
1008
|
def __call__(self):
|
|
@@ -976,6 +1041,19 @@ class _PyNativeExecutor:
|
|
|
976
1041
|
"""
|
|
977
1042
|
return self._executor.real_run_op(*args)
|
|
978
1043
|
|
|
1044
|
+
def run_op_async(self, prim, args):
|
|
1045
|
+
"""
|
|
1046
|
+
Run single op async.
|
|
1047
|
+
|
|
1048
|
+
Args:
|
|
1049
|
+
prim (Primitive): Op primitive
|
|
1050
|
+
args (tuple): input arguments.
|
|
1051
|
+
|
|
1052
|
+
Return:
|
|
1053
|
+
StubNode, result of run op.
|
|
1054
|
+
"""
|
|
1055
|
+
return self._executor.run_op_async(prim, args)
|
|
1056
|
+
|
|
979
1057
|
def new_graph(self, obj, *args, **kwargs):
|
|
980
1058
|
"""
|
|
981
1059
|
Initialize resources for building forward and backward graph.
|
|
@@ -1005,7 +1083,7 @@ class _PyNativeExecutor:
|
|
|
1005
1083
|
"""
|
|
1006
1084
|
self._executor.end_graph(obj, output, *args, *(kwargs.values()))
|
|
1007
1085
|
|
|
1008
|
-
def check_run(self, grad, obj, grad_hash_id, *args, **kwargs):
|
|
1086
|
+
def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
|
|
1009
1087
|
"""
|
|
1010
1088
|
Whether the forward graph need to construct.
|
|
1011
1089
|
|
|
@@ -1019,7 +1097,7 @@ class _PyNativeExecutor:
|
|
|
1019
1097
|
Return:
|
|
1020
1098
|
bool, specifies whether the forward graph need to construct.
|
|
1021
1099
|
"""
|
|
1022
|
-
return self._executor.check_run(grad, obj, grad_hash_id, *args, *(kwargs.values()))
|
|
1100
|
+
return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, *(kwargs.values()))
|
|
1023
1101
|
|
|
1024
1102
|
def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
|
|
1025
1103
|
"""
|
|
@@ -1082,18 +1160,6 @@ class _PyNativeExecutor:
         """
         return self._executor.grad_ms_function(output, *args)
 
-    def set_graph_phase(self, phase):
-        """
-        Set the phase of cell/function instance.
-
-        Args:
-            phase (str): The phase of cell/function instance.
-
-        Return:
-            None.
-        """
-        self._executor.set_graph_phase(phase)
-
     def grad_flag(self):
         """
         The flag of building grad graph.
@@ -1115,29 +1181,29 @@ class _PyNativeExecutor:
         """
         self._executor.set_grad_flag(flag)
 
-    def set_ms_function_compile_status(self, status):
+    def set_ms_function_compile_status(self, status, phase):
         """
         Set ms_function is compiling
 
         Args:
             status(bool): ms_function compile status
+            phase (str): The phase of cell/function instance.
         Return:
             None.
         """
-        self._executor.set_ms_function_compile_status(status)
+        self._executor.set_ms_function_compile_status(status, phase)
 
-    def set_dynamic_input(self, obj, *args):
+    def set_dynamic_input(self, obj):
         """
         Set dynamic shape tensor of input arguments.
 
         Args:
             obj (Function/Cell): The function or cell instance.
-            args (tuple): Function or cell dynamic input arguments.
 
         Return:
             None.
         """
-        self._executor.set_dynamic_input(obj, *args)
+        self._executor.set_dynamic_input(obj)
 
     def is_first_cell(self):
         """
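With `set_dynamic_input` no longer taking `*args`, the dynamic shapes are read from the function or cell instance itself. On the public side this corresponds to declaring dynamic inputs on the `Cell` via `set_inputs`, as in the hedged sketch below (a `None` dimension marks a dynamic axis; the network and shapes are illustrative):

```python
# Sketch: declare a dynamic first axis on a Cell, then run a concrete shape.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

ms.set_context(mode=ms.GRAPH_MODE)
net = nn.ReLU()
dyn_x = Tensor(shape=[None, 3], dtype=ms.float32)  # None marks a dynamic axis
net.set_inputs(dyn_x)                               # stored on the Cell instance
out = net(Tensor(np.ones((4, 3)), ms.float32))      # any batch size reuses one compile
```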
@@ -1161,14 +1227,6 @@ class _PyNativeExecutor:
         """
         self._executor.set_hook_changed(cell)
 
-    def get_optimizer(self):
-        """
-        Get the optimizer.
-
-        Return:
-            The optimizer.
-        """
-        return self._optimizer
 
     def get_top_cell(self):
         """
@@ -1179,17 +1237,6 @@ class _PyNativeExecutor:
         """
         return self._top_cell
 
-    def get_shape(self, *args):
-        """
-        Get shape of input arguments.
-
-        Args:
-            args (Tensor/tuple(Tensor)): Input arguments.
-
-        Return:
-            tuple(int), the shape of input arguments.
-        """
-        return self._executor.get_shape(*args)
 
     def constant_folding(self, *args):
         """
@@ -1291,16 +1338,17 @@ class _CellGraphExecutor:
         if "train" in phase and (enable_compile_cache is True or enable_compile_cache == "1"):
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
 
-    def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None):
+    def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
         """
         Compiles graph.
 
         Args:
             obj (Function/Cell): The function or cell instance need compile.
-            args (tuple): Function or cell input arguments.
             phase (str): The name of compile phase. Default: 'predict'.
             do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
             jit_config_dict (dict): Jit config for compile. Default: None.
+            args (tuple): Args of the Cell object.
+            kwargs (dict): Kwargs of the Cell object.
 
         Return:
             Str, the full phase of the cell.
@@ -1310,14 +1358,13 @@ class _CellGraphExecutor:
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
                 'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
-        args_list = args
 
         self.enable_tuple_broaden = False
         if hasattr(obj, "enable_tuple_broaden"):
             self.enable_tuple_broaden = obj.enable_tuple_broaden
 
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
-        key = self._graph_executor.generate_arguments_key(args_list, self.enable_tuple_broaden)
+        key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
         obj.arguments_key = str(key)
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
 
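The full phase string built above is the compile-cache key: base phase, cell creation time, object id, and the argument key are joined with dots. A small illustration with made-up values:

```python
# Illustrative values only; real ones come from the Cell and the executor.
phase = 'train'
create_time = 1679000000     # str(obj.create_time)
obj_id = 140245130           # str(id(obj))
arguments_key = '8812734'    # str(generate_arguments_key(...))
full_phase = phase + '.' + str(create_time) + '.' + str(obj_id) + '.' + arguments_key
print(full_phase)            # train.1679000000.140245130.8812734
```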
@@ -1327,14 +1374,14 @@ class _CellGraphExecutor:
 
         obj.check_names()
         _check_full_batch()
-        self._set_dataset_mode(args_list)
+        self._set_dataset_mode(args)
         self._set_compile_cache_dep_files(phase)
 
         enable_ge = context.get_context("enable_ge")
         self._graph_executor.set_weights_values(obj.parameters_dict())
         if jit_config_dict:
             self._graph_executor.set_jit_config(jit_config_dict)
-        result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
+        result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
@@ -1411,6 +1458,8 @@ class _CellGraphExecutor:
         Run the specific graph.
 
         Args:
+            obj (Cell): The cell object.
+            args (tuple): Args of the Cell object.
             phase (str): The phase name. Default: 'predict'.
 
         Returns:
@@ -1425,31 +1474,33 @@ class _CellGraphExecutor:
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
 
-    def del_net_res(self, net_id):
-        self._graph_executor.del_net_res(net_id)
+    def del_net_res(self, obj, net_id):
+        """Clear the memory resource of a network."""
+        self._graph_executor.del_net_res(obj, net_id)
+
+    def _get_branch_control_input(self):
+        if ('obf_ratio' not in self.obfuscate_config.keys()) or (
+                'obf_random_seed' not in self.obfuscate_config.keys()):
+            raise ValueError("'obf_ratio' and 'obf_random_seed' must be in obfuscate_config.")
+        obf_random_seed = self.obfuscate_config.get('obf_random_seed')
+        if obf_random_seed == 0:
+            branch_control_input = 0
+        else:
+            branch_control_input = _generate_branch_control_input(obf_random_seed)
+        return branch_control_input
 
-    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
+    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
         """Get graph proto from pipeline."""
         if use_prefix:
             exec_id = exec_id + '.' + obj.arguments_key
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
         if self.obfuscate_config is not None:
-            if ('obf_ratio' not in self.obfuscate_config.keys()) or (
-                    'obf_password' not in self.obfuscate_config.keys()):
-                raise ValueError("'obf_ratio' and 'obf_password' must be in obfuscate_config.")
-            obf_password = self.obfuscate_config.get('obf_password')
-            if obf_password == 0:
-                append_password = 0
-            else:
-                seed_max = 2 ** 32 - 1
-                int_max = 2 ** 31 - 1
-                np.random.seed(obf_password % seed_max)
-                append_password = np.random.randint(int_max)
-                obf_password %= int_max
-            return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, self.obfuscate_config['obf_ratio'],
-                                                                       obf_password, append_password)
-        return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
+            branch_control_input = self._get_branch_control_input()
+            return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, incremental,
+                                                                       self.obfuscate_config['obf_ratio'],
+                                                                       branch_control_input)
+        return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
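`_get_branch_control_input` validates that `obfuscate_config` carries the keys `'obf_ratio'` and `'obf_random_seed'`. A hedged sketch of such a config dict follows (values are examples; how the attribute is populated is outside this hunk):

```python
# Example shape of the obfuscate_config dict checked above; values made up.
obfuscate_config = {
    "obf_ratio": 0.5,         # fraction of the graph to obfuscate
    "obf_random_seed": 3423,  # 0 skips _generate_branch_control_input entirely
}
```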
@@ -1472,12 +1523,6 @@ class _CellGraphExecutor:
         """
         self._graph_executor.export_graph(file_name, graph_id, encrypt_func, enc_key)
 
-    def fetch_info_for_quant_export(self, exec_id):
-        """Get graph proto from pipeline."""
-        if self._graph_executor.has_compiled(exec_id) is False:
-            return None
-        return self._graph_executor.fetch_info_for_quant_export(exec_id)
-
 
 def ms_memory_recycle():
     """
@@ -1487,15 +1532,44 @@ def ms_memory_recycle():
     To recycle these cached memory, users can call this function after training of one model.
     """
     if ms_compile_cache:
-        _cell_graph_executor.del_net_res(ms_compile_cache)
+        _cell_graph_executor.del_net_res(None, ms_compile_cache)
         ms_compile_cache.clear()
     for cell_cache in cells_compile_cache.values():
         if cell_cache:
-            _cell_graph_executor.del_net_res(cell_cache)
+            _cell_graph_executor.del_net_res(None, cell_cache)
             cell_cache.clear()
     _ms_memory_recycle()
 
 
+def _generate_branch_control_input(obf_random_seed):
+    """Generate append network input for dynamic obfuscation in random seed mode."""
+    seed_max = 2 ** 32 - 1
+    int_max = 2 ** 31 - 1
+    np.random.seed(obf_random_seed % seed_max)
+    # generate a string as hash function inputs
+    word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
+    repo_len = len(word_repo)
+    sha_string = ''
+    string_len = 1024 * 1024
+    for _ in range(string_len):
+        rand_index = np.random.randint(0, repo_len)
+        sha_string += word_repo[rand_index]
+    # get hash result
+    sha_result = hashlib.sha256(sha_string.encode('utf-8')).hexdigest()  # len is 64
+    branch_control_input = 1
+    hex_base = 16
+    for item in sha_result:
+        if int(item, hex_base) > 0:
+            branch_control_input *= int(item, hex_base)
+    branch_control_input %= int_max
+    return branch_control_input
+
+
+def _bind_device_context():
+    """Bind device context to current thread"""
+    _bind_device_ctx()
+
+
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
 
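Since the whole body of `_generate_branch_control_input` is visible in the hunk above, it can be reproduced standalone to check that the branch control input is a pure, deterministic function of `obf_random_seed` (the 1 MiB sample string makes each evaluation deliberately expensive):

```python
# Self-contained reproduction of _generate_branch_control_input as diffed above.
import hashlib
import numpy as np

def branch_control_input(obf_random_seed):
    seed_max = 2 ** 32 - 1
    int_max = 2 ** 31 - 1
    np.random.seed(obf_random_seed % seed_max)
    # note: the "g" in place of "j" below reproduces the repo string verbatim
    word_repo = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghigklmnopqrstuvwxyz" + "0123456789"
    sha_string = ''.join(word_repo[np.random.randint(0, len(word_repo))]
                         for _ in range(1024 * 1024))
    digest = hashlib.sha256(sha_string.encode('utf-8')).hexdigest()
    result = 1
    for ch in digest:
        if int(ch, 16) > 0:
            result *= int(ch, 16)
    return result % int_max

assert branch_control_input(3423) == branch_control_input(3423)  # deterministic
```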