mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -26,6 +26,7 @@ import importlib
|
|
|
26
26
|
import platform
|
|
27
27
|
import subprocess
|
|
28
28
|
import numpy as np
|
|
29
|
+
import mindspore as ms
|
|
29
30
|
from mindspore._c_expression import Oplib, typing
|
|
30
31
|
from mindspore import context
|
|
31
32
|
from mindspore.common import Tensor
|
|
@@ -61,6 +62,27 @@ def _get_cache_path():
|
|
|
61
62
|
return cache_path
|
|
62
63
|
|
|
63
64
|
|
|
65
|
+
def _get_cuda_bare_metal_version():
|
|
66
|
+
"""
|
|
67
|
+
Automatically get the cuda version.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
tuple(str), the version of cuda of the platform.ss
|
|
71
|
+
"""
|
|
72
|
+
raw_output = subprocess.check_output(["nvcc", "-V"],
|
|
73
|
+
universal_newlines=True)
|
|
74
|
+
output = raw_output.split()
|
|
75
|
+
release_idx = output.index("release") + 1
|
|
76
|
+
release = output[release_idx].split(".")
|
|
77
|
+
version_major = release[0]
|
|
78
|
+
version_idx = release_idx + 1
|
|
79
|
+
version = output[version_idx].split(".")
|
|
80
|
+
version_middle = version[1] if len(version) > 1 else 0
|
|
81
|
+
version_minor = version[2] if len(version) > 2 else 0
|
|
82
|
+
|
|
83
|
+
return int(version_major), int(version_middle), int(version_minor)
|
|
84
|
+
|
|
85
|
+
|
|
64
86
|
def _compile_aot(file):
|
|
65
87
|
"""
|
|
66
88
|
Automatically compile the source file for custom aot
|
|
@@ -97,25 +119,21 @@ def _compile_aot(file):
|
|
|
97
119
|
Custom.compiled_bin.append(func_path)
|
|
98
120
|
|
|
99
121
|
if file.endswith("cpp") or file.endswith("cc"):
|
|
100
|
-
cmd = ["g++", "-std=c++17", "--shared", "-fPIC"]
|
|
122
|
+
cmd = ["g++", "-std=c++17", "--shared", "-fPIC", "-D_GLIBCXX_USE_CXX11_ABI=0"]
|
|
101
123
|
cmd += [include_file, "-o", func_path, file]
|
|
102
124
|
elif file.endswith("cu"):
|
|
103
125
|
cmd = ["nvcc"]
|
|
104
126
|
cmd += ["--shared", "-Xcompiler", "-fPIC", "-O3", "-gencode", "arch=compute_70, code=sm_70"]
|
|
105
|
-
cmd += ["--use_fast_math", "--expt-relaxed-constexpr"
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
return int(version_major)
|
|
116
|
-
|
|
117
|
-
if _get_cuda_bare_metal_version() >= 11:
|
|
118
|
-
cmd += ["-gencode", "arch=compute_80,code=sm_80"]
|
|
127
|
+
cmd += ["--use_fast_math", "--expt-relaxed-constexpr"]
|
|
128
|
+
cmd += ["-D_GLIBCXX_USE_CXX11_ABI=0"]
|
|
129
|
+
|
|
130
|
+
v_major, v_mid, v_minor = _get_cuda_bare_metal_version()
|
|
131
|
+
if v_major >= 11:
|
|
132
|
+
cmd += ["-gencode", "arch=compute_80,code=sm_80", "--expt-extended-lambda"]
|
|
133
|
+
elif v_major == 10 and not(v_mid >= 1 and v_minor >= 168):
|
|
134
|
+
logger.warning("The current version of nvcc, V{}.{}.{}, might have unfixed issues with std string, "
|
|
135
|
+
"which will lead to errors in aot custom op with attrs."
|
|
136
|
+
"The version higher than V10.1.168 is recommended".format(v_major, v_mid, v_minor))
|
|
119
137
|
cmd += [include_file, "-o", func_path, file]
|
|
120
138
|
else:
|
|
121
139
|
raise ValueError("The source file must be a cc/cpp/cu file, but get: {}".format(file))
|
|
@@ -141,10 +159,10 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
141
159
|
function if needed. Then these `Custom` objects can be directly used in neural networks.
|
|
142
160
|
Detailed description and introduction of user-defined operators, including correct writing of parameters,
|
|
143
161
|
please refer to `Custom Operators Tutorial
|
|
144
|
-
<https://www.mindspore.cn/tutorials/experts/en/
|
|
162
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.0/operation/op_custom.html>`_ .
|
|
145
163
|
|
|
146
164
|
.. warning::
|
|
147
|
-
This is an experimental
|
|
165
|
+
This is an experimental API that is subject to change.
|
|
148
166
|
|
|
149
167
|
.. note::
|
|
150
168
|
The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
|
|
@@ -166,7 +184,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
166
184
|
1. A AKG operator implementation function, which can use ir builder/tvm compute/hybrid grammar.
|
|
167
185
|
2. A TBE operator implementation function.
|
|
168
186
|
3. A pure python function
|
|
169
|
-
4. An
|
|
187
|
+
4. An kernel decorated function written by the Hybrid DSL.
|
|
170
188
|
|
|
171
189
|
- str: If func is of str type, then str should be a path of file along with a function name.
|
|
172
190
|
This could be used when func_type is "aot" or "julia".
|
|
@@ -317,10 +335,10 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
317
335
|
``Ascend`` ``GPU`` ``CPU``
|
|
318
336
|
|
|
319
337
|
Examples:
|
|
320
|
-
>>> import mindspore.ops as ops
|
|
321
338
|
>>> import numpy as np
|
|
322
|
-
>>> from mindspore
|
|
323
|
-
>>> from mindspore.
|
|
339
|
+
>>> from mindspore import Tensor, ops
|
|
340
|
+
>>> from mindspore.ops import CustomRegOp, custom_info_register, DataType, kernel
|
|
341
|
+
>>> from mindspore import dtype as mstype
|
|
324
342
|
>>> from mindspore.nn import Cell
|
|
325
343
|
>>> input_x = Tensor(np.ones([16, 16]).astype(np.float32))
|
|
326
344
|
>>> input_y = Tensor(np.ones([16, 16]).astype(np.float32))
|
|
@@ -329,8 +347,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
329
347
|
>>> # This is the default func_type in Custom,
|
|
330
348
|
>>> # and both out_shape and out_dtype can be None(default value).
|
|
331
349
|
>>> # In this case, the input func must be a function written in the Hybrid DSL
|
|
332
|
-
>>> # and decorated by @
|
|
333
|
-
>>> @
|
|
350
|
+
>>> # and decorated by @kernel.
|
|
351
|
+
>>> @kernel
|
|
334
352
|
... def add_script(a, b):
|
|
335
353
|
... c = output_tensor(a.shape, a.dtype)
|
|
336
354
|
... for i0 in range(a.shape[0]):
|
|
@@ -436,6 +454,9 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
436
454
|
registered_func = {}
|
|
437
455
|
attr_dict = {} # Save input_names and attr_names for func.
|
|
438
456
|
compiled_bin = [] # Save names for compiled bin.
|
|
457
|
+
tbe_path_checked = [] # Save paths for tbe functions which is safe to be imported as module.
|
|
458
|
+
tbe_path_failed = [] # Save paths for tbe functions which fail to be imported as module.
|
|
459
|
+
op_path_in_cache = [] # Save paths for op functions created in the cached.
|
|
439
460
|
|
|
440
461
|
def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
|
|
441
462
|
ops.PrimitiveWithInfer.__init__(self, "Custom")
|
|
@@ -453,7 +474,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
453
474
|
self._is_ms_kernel = False
|
|
454
475
|
|
|
455
476
|
self._check_func()
|
|
456
|
-
self._update_func_info()
|
|
477
|
+
self._update_func_info(reg_info)
|
|
457
478
|
self.add_prim_attr("func_name", self.func_name)
|
|
458
479
|
self.add_prim_attr("uniq_name", self.uniq_name)
|
|
459
480
|
if self.func_type == "hybrid":
|
|
@@ -468,15 +489,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
468
489
|
self.out_shape = out_shape
|
|
469
490
|
self.out_dtype = out_dtype
|
|
470
491
|
self.bprop = bprop
|
|
471
|
-
self.
|
|
472
|
-
self.single_scalar_output = False
|
|
473
|
-
if not self.out_dtype:
|
|
474
|
-
self.fake_output = True
|
|
475
|
-
elif not self.out_shape:
|
|
476
|
-
self.single_scalar_output = True
|
|
477
|
-
self.add_prim_attr("fake_output", self.fake_output)
|
|
478
|
-
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
|
|
479
|
-
|
|
492
|
+
self._update_op_attr()
|
|
480
493
|
# Register info
|
|
481
494
|
self._register_info(reg_info)
|
|
482
495
|
|
|
@@ -497,6 +510,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
497
510
|
self._update_attr()
|
|
498
511
|
|
|
499
512
|
def __infer__(self, *args):
|
|
513
|
+
"""Infer function of the custom op"""
|
|
500
514
|
if callable(self.out_shape):
|
|
501
515
|
infer_shape = self.out_shape(*(x["shape"] for x in args))
|
|
502
516
|
else:
|
|
@@ -533,7 +547,12 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
533
547
|
logger.warning("{}, 'out_dtype' is an empty tuple. Add a placeholder instead. "
|
|
534
548
|
"Not recommend to use it as it could be any uninitialized data.".format(self.log_prefix))
|
|
535
549
|
infer_dtype = mstype.int32
|
|
536
|
-
|
|
550
|
+
if self.func_type == "aot":
|
|
551
|
+
if infer_shape is None:
|
|
552
|
+
logger.warning("{}, 'out_shape' is None. Add a placeholder instead. "
|
|
553
|
+
"A CPP version of infer shape function is required "
|
|
554
|
+
"in this case.".format(self.log_prefix))
|
|
555
|
+
infer_shape = (1,)
|
|
537
556
|
# after all automatic infer information fulfillment, throw error if infer_shape/infer_dtype is still None
|
|
538
557
|
if not isinstance(infer_shape, (tuple, list)):
|
|
539
558
|
raise TypeError("{}, 'out_shape' must be one of [tuple, list, function], but got {}"
|
|
@@ -551,8 +570,22 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
551
570
|
return out
|
|
552
571
|
|
|
553
572
|
def get_bprop(self):
|
|
573
|
+
"""Get the bprop of the custom op"""
|
|
554
574
|
return self.bprop
|
|
555
575
|
|
|
576
|
+
def _update_op_attr(self):
|
|
577
|
+
"""Update the attrs of the custom op"""
|
|
578
|
+
if self.out_shape is None and self.func_type == "aot":
|
|
579
|
+
self.add_prim_attr("cpp_infer_shape", True)
|
|
580
|
+
self.fake_output = False
|
|
581
|
+
self.single_scalar_output = False
|
|
582
|
+
if not self.out_dtype:
|
|
583
|
+
self.fake_output = True
|
|
584
|
+
elif not self.out_shape:
|
|
585
|
+
self.single_scalar_output = True
|
|
586
|
+
self.add_prim_attr("fake_output", self.fake_output)
|
|
587
|
+
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
|
|
588
|
+
|
|
556
589
|
def _check_julia_func(self):
|
|
557
590
|
"""Check the validity of julia func"""
|
|
558
591
|
if not isinstance(self.func, str):
|
|
@@ -592,17 +625,17 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
592
625
|
self._check_julia_func()
|
|
593
626
|
elif self.func_type == "hybrid":
|
|
594
627
|
if not hasattr(self.func, "ms_kernel_flag"):
|
|
595
|
-
raise TypeError("{}, 'func' must be a function decorated by
|
|
628
|
+
raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
|
|
596
629
|
self._is_ms_kernel = True
|
|
597
630
|
self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
|
|
598
631
|
elif self.func_type == "akg":
|
|
599
632
|
if hasattr(self.func, "ms_kernel_flag"):
|
|
600
633
|
logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
|
|
601
|
-
"for the input function with decorator @
|
|
634
|
+
"for the input function with decorator @kernel. "
|
|
602
635
|
"To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
|
|
603
636
|
elif self.func_type == "pyfunc":
|
|
604
637
|
if hasattr(self.func, "ms_kernel_flag"):
|
|
605
|
-
logger.warning("{}. Now you are using the function with decorator @
|
|
638
|
+
logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
|
|
606
639
|
"The kernel will be executed as a native python function, which might lead to "
|
|
607
640
|
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
|
|
608
641
|
.format(self.log_prefix))
|
|
@@ -611,7 +644,64 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
611
644
|
raise TypeError("{}, 'func' must be of type function, but got {}"
|
|
612
645
|
.format(self.log_prefix, type(self.func)))
|
|
613
646
|
|
|
614
|
-
def
|
|
647
|
+
def _update_func_imply_path(self):
|
|
648
|
+
"""Update op_imply_path of func"""
|
|
649
|
+
file_path = os.path.realpath(inspect.getfile(self.func))
|
|
650
|
+
|
|
651
|
+
if not self.func_type == "tbe":
|
|
652
|
+
# Custom ops with type other than tbe don't need to import from the path
|
|
653
|
+
# use the file path directly
|
|
654
|
+
return file_path
|
|
655
|
+
# For the custom op of type tbe, the kernel compiler will import the module from file path.
|
|
656
|
+
# we will try import in the initialization,
|
|
657
|
+
if file_path in Custom.tbe_path_checked:
|
|
658
|
+
logger.info("The file of {} has already been checked good to be imported.".format(self.func_name))
|
|
659
|
+
return file_path
|
|
660
|
+
|
|
661
|
+
if file_path not in Custom.tbe_path_failed:
|
|
662
|
+
# As a single file might include multiple functions
|
|
663
|
+
# we will not try the file path which already failed in previous trials
|
|
664
|
+
mod_spec = importlib.util.spec_from_file_location(
|
|
665
|
+
self.func_name, file_path)
|
|
666
|
+
custom_mod = importlib.util.module_from_spec(mod_spec)
|
|
667
|
+
try:
|
|
668
|
+
mod_spec.loader.exec_module(custom_mod)
|
|
669
|
+
except (ImportError, RecursionError):
|
|
670
|
+
Custom.tbe_path_failed.append(file_path)
|
|
671
|
+
else:
|
|
672
|
+
Custom.tbe_path_checked.append(file_path)
|
|
673
|
+
return file_path
|
|
674
|
+
|
|
675
|
+
# Create a new file for each tbe function
|
|
676
|
+
op_imply_path = os.path.realpath(_get_cache_path() + self.func_name + ".py")
|
|
677
|
+
if op_imply_path in Custom.op_path_in_cache:
|
|
678
|
+
logger.info("The new file of {} has already been created.".format(self.func_name))
|
|
679
|
+
return op_imply_path
|
|
680
|
+
|
|
681
|
+
logger.warning("Fail to import the original source file. Create a new source file for {}. "
|
|
682
|
+
"The new file will not include the dependency for the op function. "
|
|
683
|
+
"Check the definition of the function {} "
|
|
684
|
+
"in the file: {}".format(self.func_name, self.func_name, op_imply_path))
|
|
685
|
+
|
|
686
|
+
Custom.op_path_in_cache.append(op_imply_path)
|
|
687
|
+
|
|
688
|
+
if os.path.exists(op_imply_path):
|
|
689
|
+
try:
|
|
690
|
+
os.remove(op_imply_path)
|
|
691
|
+
except FileNotFoundError:
|
|
692
|
+
logger.warning("Fail to remove the existing file. Check the definition of the function {} "
|
|
693
|
+
"in the file: {}".format(self.func_name, op_imply_path))
|
|
694
|
+
|
|
695
|
+
with open(op_imply_path, 'at') as file:
|
|
696
|
+
if platform.system() != "Windows":
|
|
697
|
+
fcntl.flock(file.fileno(), fcntl.LOCK_EX)
|
|
698
|
+
file.seek(0, 2)
|
|
699
|
+
if file.tell() == 0:
|
|
700
|
+
file.write(self.func_source_str)
|
|
701
|
+
os.chmod(op_imply_path, stat.S_IRUSR | stat.S_IWUSR)
|
|
702
|
+
return op_imply_path
|
|
703
|
+
|
|
704
|
+
def _update_func_info(self, reg_info):
|
|
615
705
|
"""Update information of func"""
|
|
616
706
|
if callable(self.func):
|
|
617
707
|
# For the func_type other than hybrid, get the original function if func is decorated
|
|
@@ -626,19 +716,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
626
716
|
if index != -1:
|
|
627
717
|
self.func_source_str = self.func_source_str[index:]
|
|
628
718
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
os.remove(op_imply_path)
|
|
632
|
-
with open(op_imply_path, 'at') as file:
|
|
633
|
-
if platform.system() != "Windows":
|
|
634
|
-
fcntl.flock(file.fileno(), fcntl.LOCK_EX)
|
|
635
|
-
file.seek(0, 2)
|
|
636
|
-
if file.tell() == 0:
|
|
637
|
-
file.write(self.func_source_str)
|
|
638
|
-
os.chmod(op_imply_path, stat.S_IRUSR | stat.S_IWUSR)
|
|
639
|
-
|
|
640
|
-
# path of func
|
|
641
|
-
self.imply_path = op_imply_path
|
|
719
|
+
# update path of func for TBE type of custom op
|
|
720
|
+
self.imply_path = self._update_func_imply_path()
|
|
642
721
|
if self._is_ms_kernel:
|
|
643
722
|
# static check for the Hybrid DSL in hybrid
|
|
644
723
|
root = ast.parse(self.func_source_str)
|
|
@@ -660,11 +739,38 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
660
739
|
# func name
|
|
661
740
|
self.func_name = self.func
|
|
662
741
|
# uniq func name
|
|
663
|
-
|
|
742
|
+
prefix = self.name
|
|
743
|
+
if reg_info is None:
|
|
744
|
+
reg_info = {}
|
|
745
|
+
reg_info_list = self._get_expanded_list(reg_info)
|
|
746
|
+
for reg_info_item in reg_info_list:
|
|
747
|
+
if not isinstance(reg_info_item, (str, dict)):
|
|
748
|
+
continue
|
|
749
|
+
if isinstance(reg_info_item, str):
|
|
750
|
+
reg_info_item = json.loads(reg_info_item)
|
|
751
|
+
prefix = "_".join([prefix, reg_info_item.get("op_name", "")])
|
|
752
|
+
self.uniq_name = prefix + "_" + self.func_name
|
|
664
753
|
else:
|
|
665
754
|
raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
|
|
666
755
|
.format(self.name, type(self.func)))
|
|
667
756
|
|
|
757
|
+
def _update_reg_attrs(self, reg_info):
|
|
758
|
+
"""Update op attrs in reg_info."""
|
|
759
|
+
for _, item in enumerate(reg_info.get("outputs", [])):
|
|
760
|
+
output_name_list = []
|
|
761
|
+
if isinstance(item, dict) and item.get("name"):
|
|
762
|
+
output_name_list.append(item.get("name"))
|
|
763
|
+
self.add_prim_attr("output_names", output_name_list)
|
|
764
|
+
|
|
765
|
+
if isinstance(reg_info.get("op_name"), str):
|
|
766
|
+
self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
|
|
767
|
+
|
|
768
|
+
if self.func_type == "aot":
|
|
769
|
+
if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
|
|
770
|
+
for item in reg_info["attr"]:
|
|
771
|
+
if isinstance(item, dict) and item.get("value") is not None:
|
|
772
|
+
self.add_prim_attr(item["name"], item["value"])
|
|
773
|
+
|
|
668
774
|
def _register_info(self, info):
|
|
669
775
|
"""Register reg_info."""
|
|
670
776
|
reg_info = info
|
|
@@ -687,18 +793,11 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
687
793
|
new_dtype_format.append(i + (DataType.I32_Default,))
|
|
688
794
|
reg_info["dtype_format"] = new_dtype_format
|
|
689
795
|
|
|
690
|
-
|
|
691
|
-
output_name_list = []
|
|
692
|
-
if isinstance(item, dict) and item.get("name"):
|
|
693
|
-
output_name_list.append(item.get("name"))
|
|
694
|
-
self.add_prim_attr("output_names", output_name_list)
|
|
695
|
-
|
|
696
|
-
if isinstance(reg_info.get("op_name"), str):
|
|
697
|
-
self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
|
|
796
|
+
self._update_reg_attrs(reg_info)
|
|
698
797
|
|
|
699
798
|
target = self._get_target(reg_info)
|
|
700
799
|
# Reg info for func is only registered once for a certain target
|
|
701
|
-
if self._has_registered(target
|
|
800
|
+
if self._has_registered(target):
|
|
702
801
|
continue
|
|
703
802
|
# Register
|
|
704
803
|
reg_info = self._reformat_reg_info(reg_info, target)
|
|
@@ -710,7 +809,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
710
809
|
"'custom_info_register' to bind it to 'func' if 'func' is a function."
|
|
711
810
|
.format(self.log_prefix))
|
|
712
811
|
self._save_attr(reg_info)
|
|
713
|
-
self._save_register_status(target
|
|
812
|
+
self._save_register_status(target)
|
|
714
813
|
|
|
715
814
|
def _get_expanded_list(self, data):
|
|
716
815
|
"""Recursive function to parse elements in list or tuple."""
|
|
@@ -724,43 +823,34 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
724
823
|
data_list.append(data)
|
|
725
824
|
return data_list
|
|
726
825
|
|
|
727
|
-
def _get_registered_targets(self, reg_info=None):
|
|
826
|
+
def _get_registered_targets(self):
|
|
728
827
|
"""Get the registered targets of func."""
|
|
729
828
|
targets = []
|
|
730
|
-
if reg_info is None:
|
|
731
|
-
reg_info = {}
|
|
732
829
|
if callable(self.func):
|
|
733
830
|
targets = getattr(self.func, "registered_targets", [])
|
|
734
831
|
elif isinstance(self.func, str):
|
|
735
|
-
|
|
736
|
-
reg_op_name = reg_info.get("op_name")
|
|
737
|
-
else:
|
|
738
|
-
reg_op_name = self.func
|
|
739
|
-
targets = Custom.registered_func.get(reg_op_name, [])
|
|
832
|
+
targets = Custom.registered_func.get(self.uniq_name, [])
|
|
740
833
|
if not isinstance(targets, list):
|
|
741
834
|
targets = [targets]
|
|
742
835
|
return targets
|
|
743
836
|
|
|
744
|
-
def _has_registered(self, target
|
|
837
|
+
def _has_registered(self, target):
|
|
745
838
|
"""Check if registration information is registered in target."""
|
|
746
|
-
registered_targets = self._get_registered_targets(
|
|
839
|
+
registered_targets = self._get_registered_targets()
|
|
747
840
|
return target in registered_targets
|
|
748
841
|
|
|
749
|
-
def _save_register_status(self, target
|
|
842
|
+
def _save_register_status(self, target):
|
|
750
843
|
"""Save registration status for target."""
|
|
751
844
|
if callable(self.func):
|
|
752
845
|
registered_targets = getattr(self.func, "registered_targets", [])
|
|
753
846
|
registered_targets.append(target)
|
|
754
847
|
setattr(self.func, "registered_targets", registered_targets)
|
|
755
848
|
elif isinstance(self.func, str):
|
|
756
|
-
|
|
757
|
-
|
|
849
|
+
func_name = self.uniq_name
|
|
850
|
+
if isinstance(Custom.registered_func.get(func_name), list):
|
|
851
|
+
Custom.registered_func.get(func_name).append(target)
|
|
758
852
|
else:
|
|
759
|
-
|
|
760
|
-
if isinstance(Custom.registered_func.get(reg_op_name), list):
|
|
761
|
-
Custom.registered_func.get(reg_op_name).append(target)
|
|
762
|
-
else:
|
|
763
|
-
Custom.registered_func[reg_op_name] = [target]
|
|
853
|
+
Custom.registered_func[func_name] = [target]
|
|
764
854
|
|
|
765
855
|
def _get_op_name(self, reg_info):
|
|
766
856
|
if self.func_type == "aicpu":
|
|
@@ -786,20 +876,15 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
786
876
|
if isinstance(item, dict) and item.get("value") is None:
|
|
787
877
|
reg_info["attr"][i]["value"] = "all"
|
|
788
878
|
reg_info["async_flag"] = reg_info.get("async_flag", False)
|
|
789
|
-
reg_info["
|
|
879
|
+
reg_info["binfile"] = "%s.so" % self.func_name
|
|
790
880
|
reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
|
|
791
|
-
reg_info["
|
|
881
|
+
reg_info["kernel"] = self.func_name
|
|
792
882
|
reg_info["partial_flag"] = reg_info.get("partial_flag", True)
|
|
793
|
-
reg_info["
|
|
883
|
+
reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
|
|
794
884
|
# Supplement necessary info for AKG if this information is missing in reg_info
|
|
795
885
|
if reg_info["imply_type"] == "AKG":
|
|
796
886
|
target_to_processor = {"Ascend": "AiCore", "GPU": "CUDA", "CPU": "CPU"}
|
|
797
887
|
reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
|
|
798
|
-
if self.func_type == "aot":
|
|
799
|
-
if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
|
|
800
|
-
for item in reg_info["attr"]:
|
|
801
|
-
if isinstance(item, dict) and item.get("value") is not None:
|
|
802
|
-
self.add_prim_attr(item["name"], item["value"])
|
|
803
888
|
return reg_info
|
|
804
889
|
|
|
805
890
|
def _get_target(self, reg_info):
|
|
@@ -833,8 +918,8 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
833
918
|
reg_info["imply_type"].strip():
|
|
834
919
|
return reg_info["imply_type"]
|
|
835
920
|
# Infer imply_type from func_type
|
|
836
|
-
func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "
|
|
837
|
-
"
|
|
921
|
+
func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "pyfunc": target,
|
|
922
|
+
"julia": target, "aot": "BiSheng" if target == "Ascend" else target}
|
|
838
923
|
return func_type_to_imply_type.get(self.func_type, "AKG")
|
|
839
924
|
|
|
840
925
|
def _save_attr(self, reg_info):
|
|
@@ -854,9 +939,11 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
854
939
|
for item in tensor_inputs:
|
|
855
940
|
if isinstance(item, dict) and item.get("name") is not None:
|
|
856
941
|
input_names.append(item["name"])
|
|
942
|
+
has_input_name = bool(input_names)
|
|
857
943
|
for item in attr:
|
|
858
944
|
if isinstance(item, dict) and item.get("name") is not None:
|
|
859
|
-
|
|
945
|
+
if has_input_name or context.get_context("mode") != ms.PYNATIVE_MODE:
|
|
946
|
+
input_names.append(item["name"])
|
|
860
947
|
attr_names.append(item["name"])
|
|
861
948
|
cur_attr = {"input_names": input_names, "attr_names": attr_names}
|
|
862
949
|
# If func does not have attr, save current attr.
|
|
@@ -882,7 +969,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
882
969
|
"""Add primitive_target to primitive's attr."""
|
|
883
970
|
registered_targets = self._get_registered_targets()
|
|
884
971
|
if self.func_type == "pyfunc":
|
|
885
|
-
self.
|
|
972
|
+
self.set_device("CPU")
|
|
886
973
|
if registered_targets and registered_targets != ["CPU"]:
|
|
887
974
|
logger.warning("{}, only supports CPU platform, but got registered target {}. "
|
|
888
975
|
"We will run it on CPU".format(self.log_prefix, registered_targets))
|
|
@@ -890,11 +977,11 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
890
977
|
if len(registered_targets) != 1:
|
|
891
978
|
logger.info("{}, target will be set according to context.".format(self.log_prefix))
|
|
892
979
|
elif registered_targets == ["GPU"]:
|
|
893
|
-
self.
|
|
980
|
+
self.set_device("GPU")
|
|
894
981
|
elif registered_targets == ["CPU"]:
|
|
895
|
-
self.
|
|
982
|
+
self.set_device("CPU")
|
|
896
983
|
elif self.func_type == "julia":
|
|
897
|
-
self.
|
|
984
|
+
self.set_device("CPU")
|
|
898
985
|
device_target = context.get_context('device_target')
|
|
899
986
|
if device_target == "CPU":
|
|
900
987
|
pass
|
|
@@ -962,12 +1049,12 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
962
1049
|
for idx, val in enumerate(symbols)]
|
|
963
1050
|
|
|
964
1051
|
if any(i[1] != -1 for i in inplace_assign_output):
|
|
965
|
-
|
|
966
|
-
|
|
1052
|
+
self.add_prim_attr("inplace_assign_output", " ".join(
|
|
1053
|
+
(str(j) for i in inplace_assign_output for j in i)))
|
|
967
1054
|
|
|
968
1055
|
def _auto_infer(self, *args):
|
|
969
1056
|
"""
|
|
970
|
-
the automatic infer function for functions with @ms_kernel decorator
|
|
1057
|
+
the automatic infer function for functions with @kernel decorator
|
|
971
1058
|
"""
|
|
972
1059
|
fake_input = []
|
|
973
1060
|
enable_infer_value = True
|