mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -62,27 +62,6 @@ def _get_cache_path():
|
|
|
62
62
|
return cache_path
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
def _get_cuda_bare_metal_version():
|
|
66
|
-
"""
|
|
67
|
-
Automatically get the cuda version.
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
tuple(str), the version of cuda of the platform.ss
|
|
71
|
-
"""
|
|
72
|
-
raw_output = subprocess.check_output(["nvcc", "-V"],
|
|
73
|
-
universal_newlines=True)
|
|
74
|
-
output = raw_output.split()
|
|
75
|
-
release_idx = output.index("release") + 1
|
|
76
|
-
release = output[release_idx].split(".")
|
|
77
|
-
version_major = release[0]
|
|
78
|
-
version_idx = release_idx + 1
|
|
79
|
-
version = output[version_idx].split(".")
|
|
80
|
-
version_middle = version[1] if len(version) > 1 else 0
|
|
81
|
-
version_minor = version[2] if len(version) > 2 else 0
|
|
82
|
-
|
|
83
|
-
return int(version_major), int(version_middle), int(version_minor)
|
|
84
|
-
|
|
85
|
-
|
|
86
65
|
def _compile_aot(file):
|
|
87
66
|
"""
|
|
88
67
|
Automatically compile the source file for custom aot
|
|
@@ -99,11 +78,7 @@ def _compile_aot(file):
|
|
|
99
78
|
cache_path = os.path.join(cache_path, "rank_" + str(get_rank()), "")
|
|
100
79
|
os.makedirs(cache_path, exist_ok=True)
|
|
101
80
|
|
|
102
|
-
|
|
103
|
-
if search_res is None:
|
|
104
|
-
raise RuntimeError("Cannot find mindspore module!")
|
|
105
|
-
|
|
106
|
-
res_path = search_res.origin
|
|
81
|
+
res_path = importlib.util.find_spec("mindspore").origin
|
|
107
82
|
find_pos = res_path.find("__init__.py")
|
|
108
83
|
if find_pos == -1:
|
|
109
84
|
raise RuntimeError(
|
|
@@ -111,9 +86,8 @@ def _compile_aot(file):
|
|
|
111
86
|
include_file = "-I{}include/api/".format(res_path[:find_pos])
|
|
112
87
|
|
|
113
88
|
file_name = file.split('/')[-1]
|
|
114
|
-
file_folder = file[:file.rindex('/')]
|
|
115
89
|
func_path = cache_path + file_name + ".so"
|
|
116
|
-
include_file = "{} -I{}".format(include_file,
|
|
90
|
+
include_file = "{} -I{}".format(include_file, file[:file.rindex('/')])
|
|
117
91
|
|
|
118
92
|
if func_path not in Custom.compiled_bin:
|
|
119
93
|
Custom.compiled_bin.append(func_path)
|
|
@@ -127,10 +101,23 @@ def _compile_aot(file):
|
|
|
127
101
|
cmd += ["--use_fast_math", "--expt-relaxed-constexpr"]
|
|
128
102
|
cmd += ["-D_GLIBCXX_USE_CXX11_ABI=0"]
|
|
129
103
|
|
|
104
|
+
def _get_cuda_bare_metal_version():
|
|
105
|
+
raw_output = subprocess.check_output(["nvcc", "-V"],
|
|
106
|
+
universal_newlines=True)
|
|
107
|
+
output = raw_output.split()
|
|
108
|
+
release_idx = output.index("release") + 1
|
|
109
|
+
release = output[release_idx].split(".")
|
|
110
|
+
version_idx = release_idx + 1
|
|
111
|
+
version = output[version_idx].split(".")
|
|
112
|
+
version_middle = version[1] if len(version) > 1 else 0
|
|
113
|
+
version_minor = version[2] if len(version) > 2 else 0
|
|
114
|
+
|
|
115
|
+
return int(release[0]), int(version_middle), int(version_minor)
|
|
116
|
+
|
|
130
117
|
v_major, v_mid, v_minor = _get_cuda_bare_metal_version()
|
|
131
118
|
if v_major >= 11:
|
|
132
119
|
cmd += ["-gencode", "arch=compute_80,code=sm_80", "--expt-extended-lambda"]
|
|
133
|
-
elif v_major == 10 and not(v_mid >= 1 and v_minor >= 168):
|
|
120
|
+
elif v_major == 10 and not (v_mid >= 1 and v_minor >= 168):
|
|
134
121
|
logger.warning("The current version of nvcc, V{}.{}.{}, might have unfixed issues with std string, "
|
|
135
122
|
"which will lead to errors in aot custom op with attrs."
|
|
136
123
|
"The version higher than V10.1.168 is recommended".format(v_major, v_mid, v_minor))
|
|
@@ -159,10 +146,11 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
159
146
|
function if needed. Then these `Custom` objects can be directly used in neural networks.
|
|
160
147
|
Detailed description and introduction of user-defined operators, including correct writing of parameters,
|
|
161
148
|
please refer to `Custom Operators Tutorial
|
|
162
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
149
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html>`_ .
|
|
163
150
|
|
|
164
151
|
.. warning::
|
|
165
|
-
This is an experimental API that is subject to change.
|
|
152
|
+
- This is an experimental API that is subject to change.
|
|
153
|
+
- Currently, the functionality of Custom does not support Ascend 910B.
|
|
166
154
|
|
|
167
155
|
.. note::
|
|
168
156
|
The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
|
|
@@ -175,6 +163,12 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
175
163
|
- "julia": supports ["CPU"].
|
|
176
164
|
- "aicpu": supports ["Ascend"].
|
|
177
165
|
|
|
166
|
+
If run on ge backend, use `CustomRegOp` to generate the registration information of "aicpu" and "tbe" operator,
|
|
167
|
+
use `custom_info_register` to bind the registration information to the `func` of the "tbe" operator,
|
|
168
|
+
then save the registration information of "aicpu" operator and the `func` implementation of "tbe" operator to
|
|
169
|
+
a file or separate files, keep these files in a separate directory, and set the absolute path of this directory
|
|
170
|
+
to environment variable "MS_DEV_CUSTOM_OPP_PATH" before running the network.
|
|
171
|
+
|
|
178
172
|
Args:
|
|
179
173
|
func (Union[function, str]):
|
|
180
174
|
|
|
@@ -265,7 +259,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
265
259
|
(ex. Custom(func="./add.jl:Add:add", out_shape=[1], out_dtype=mstype.float32, "julia"))
|
|
266
260
|
|
|
267
261
|
out_shape (Union[function, list, tuple]): The output shape infer function or the value of output shape of
|
|
268
|
-
`func`. Default: None.
|
|
262
|
+
`func`. Default: ``None`` .
|
|
269
263
|
|
|
270
264
|
If func has single output, then the value of output shape is a list or tuple of int.
|
|
271
265
|
|
|
@@ -276,7 +270,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
276
270
|
shape mechanic will be enabled.
|
|
277
271
|
|
|
278
272
|
out_dtype (Union[function, :class:`mindspore.dtype`, tuple[:class:`mindspore.dtype`]]): The output data type
|
|
279
|
-
infer function or the value of output data type of `func`. Default: None.
|
|
273
|
+
infer function or the value of output data type of `func`. Default: ``None`` .
|
|
280
274
|
|
|
281
275
|
If func has single output, then the value of output shape is a `mindspore.dtype`.
|
|
282
276
|
|
|
@@ -288,23 +282,23 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
288
282
|
|
|
289
283
|
func_type (str): The implementation type of `func`, should be one of
|
|
290
284
|
|
|
291
|
-
["hybrid", "akg", "tbe", "aot", "pyfunc", "julia", "aicpu"].
|
|
285
|
+
[ ``"hybrid"`` , ``"akg"`` , ``"tbe"`` , ``"aot"`` , ``"pyfunc"`` , ``"julia"`` , ``"aicpu"`` ].
|
|
292
286
|
|
|
293
|
-
Each `func_type` only supports specific platforms(targets). Default: "hybrid".
|
|
287
|
+
Each `func_type` only supports specific platforms(targets). Default: ``"hybrid"`` .
|
|
294
288
|
The supported platforms of `func_type`:
|
|
295
289
|
|
|
296
|
-
- "hybrid"
|
|
297
|
-
- "akg"
|
|
298
|
-
- "tbe"
|
|
299
|
-
- "aot"
|
|
300
|
-
- "pyfunc"
|
|
301
|
-
- "julia"
|
|
302
|
-
- "aicpu"
|
|
290
|
+
- ``"hybrid"``: supports ["Ascend", "GPU", "CPU"].
|
|
291
|
+
- ``"akg"``: supports ["Ascend", "GPU", "CPU"].
|
|
292
|
+
- ``"tbe"``: supports ["Ascend"].
|
|
293
|
+
- ``"aot"``: supports ["GPU", "CPU"].
|
|
294
|
+
- ``"pyfunc"``: supports ["CPU"].
|
|
295
|
+
- ``"julia"``: supports ["CPU"].
|
|
296
|
+
- ``"aicpu"``: supports ["Ascend"].
|
|
303
297
|
|
|
304
|
-
bprop (function): The back propagation function of `func`. Default: None.
|
|
298
|
+
bprop (function): The back propagation function of `func`. Default: ``None`` .
|
|
305
299
|
reg_info (Union[str, dict, list, tuple]): Represents the registration information(reg info) of `func` with
|
|
306
300
|
json format of type str or dict. The reg info specifies supported data types and formats of inputs and
|
|
307
|
-
outputs, attributes and target of `func`. Default: None.
|
|
301
|
+
outputs, attributes and target of `func`. Default: ``None`` .
|
|
308
302
|
|
|
309
303
|
If reg info is a list or tuple, then each item should be with json format of type str or dict, which
|
|
310
304
|
represents the registration information of `func` in a specific target. You need to invoke `CustomRegOp`
|
|
@@ -457,6 +451,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
457
451
|
tbe_path_checked = [] # Save paths for tbe functions which is safe to be imported as module.
|
|
458
452
|
tbe_path_failed = [] # Save paths for tbe functions which fail to be imported as module.
|
|
459
453
|
op_path_in_cache = [] # Save paths for op functions created in the cached.
|
|
454
|
+
custom_aot_warning = True # Flag to enable warnings about custom aot path white list
|
|
460
455
|
|
|
461
456
|
def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
|
|
462
457
|
ops.PrimitiveWithInfer.__init__(self, "Custom")
|
|
@@ -473,6 +468,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
473
468
|
self._func_compile_attrs = {}
|
|
474
469
|
self._is_ms_kernel = False
|
|
475
470
|
|
|
471
|
+
self._check_platform()
|
|
476
472
|
self._check_func()
|
|
477
473
|
self._update_func_info(reg_info)
|
|
478
474
|
self.add_prim_attr("func_name", self.func_name)
|
|
@@ -487,21 +483,24 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
487
483
|
self.add_prim_attr("fn_id", func_id)
|
|
488
484
|
|
|
489
485
|
self.out_shape = out_shape
|
|
486
|
+
if self.out_shape is None and self.func_type == "aot":
|
|
487
|
+
self.add_prim_attr("cpp_infer_shape", True)
|
|
490
488
|
self.out_dtype = out_dtype
|
|
491
489
|
self.bprop = bprop
|
|
492
|
-
self.
|
|
490
|
+
self.fake_output = False
|
|
491
|
+
self.single_scalar_output = False
|
|
492
|
+
if not self.out_dtype:
|
|
493
|
+
self.fake_output = True
|
|
494
|
+
elif not self.out_shape:
|
|
495
|
+
self.single_scalar_output = True
|
|
496
|
+
self.add_prim_attr("fake_output", self.fake_output)
|
|
497
|
+
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
|
|
498
|
+
|
|
493
499
|
# Register info
|
|
494
500
|
self._register_info(reg_info)
|
|
495
501
|
|
|
496
502
|
if func_type == "akg":
|
|
497
|
-
self.
|
|
498
|
-
if "ir_builder" in self.func_source_str:
|
|
499
|
-
self.func_type = "ir_builder"
|
|
500
|
-
elif "compute" in self.func_source_str:
|
|
501
|
-
self.func_type = "tvm_compute"
|
|
502
|
-
else:
|
|
503
|
-
self.func_type = "hybrid"
|
|
504
|
-
self._hybrid_func_analyser()
|
|
503
|
+
self._set_akg_kernel_type()
|
|
505
504
|
|
|
506
505
|
if not self.bprop and self.func_type == "hybrid":
|
|
507
506
|
self._hybrid_autodiff(func_type)
|
|
@@ -510,7 +509,6 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
510
509
|
self._update_attr()
|
|
511
510
|
|
|
512
511
|
def __infer__(self, *args):
|
|
513
|
-
"""Infer function of the custom op"""
|
|
514
512
|
if callable(self.out_shape):
|
|
515
513
|
infer_shape = self.out_shape(*(x["shape"] for x in args))
|
|
516
514
|
else:
|
|
@@ -570,21 +568,17 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
570
568
|
return out
|
|
571
569
|
|
|
572
570
|
def get_bprop(self):
|
|
573
|
-
"""Get the bprop of the custom op"""
|
|
574
571
|
return self.bprop
|
|
575
572
|
|
|
576
|
-
def
|
|
577
|
-
|
|
578
|
-
if
|
|
579
|
-
self.
|
|
580
|
-
self.
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
self.
|
|
584
|
-
|
|
585
|
-
self.single_scalar_output = True
|
|
586
|
-
self.add_prim_attr("fake_output", self.fake_output)
|
|
587
|
-
self.add_prim_attr("single_scalar_output", self.single_scalar_output)
|
|
573
|
+
def _set_akg_kernel_type(self):
|
|
574
|
+
self.add_prim_attr('func_source_str', self.func_source_str)
|
|
575
|
+
if "ir_builder" in self.func_source_str:
|
|
576
|
+
self.func_type = "ir_builder"
|
|
577
|
+
elif "compute" in self.func_source_str:
|
|
578
|
+
self.func_type = "tvm_compute"
|
|
579
|
+
else:
|
|
580
|
+
self.func_type = "hybrid"
|
|
581
|
+
self._hybrid_func_analyser()
|
|
588
582
|
|
|
589
583
|
def _check_julia_func(self):
|
|
590
584
|
"""Check the validity of julia func"""
|
|
@@ -602,6 +596,10 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
602
596
|
raise Exception("{}, function {} is not found in source file {}!"
|
|
603
597
|
.format(self.log_prefix, func, source_file))
|
|
604
598
|
|
|
599
|
+
def _check_platform(self):
|
|
600
|
+
if platform.system() != 'Linux':
|
|
601
|
+
raise Exception("Custom op only supported on Linux platform currently.")
|
|
602
|
+
|
|
605
603
|
def _check_func(self):
|
|
606
604
|
"""Check the validity of func_type and type of func"""
|
|
607
605
|
if self.func_type not in self.supported_func_type:
|
|
@@ -617,7 +615,19 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
617
615
|
"{}, 'func' should be like 'file_name:func_name', but got {}".format(
|
|
618
616
|
self.log_prefix, self.func))
|
|
619
617
|
file_path = os.path.abspath(file_name_list[0])
|
|
620
|
-
if
|
|
618
|
+
if os.environ.get('MS_CUSTOM_AOT_WHITE_LIST') is None:
|
|
619
|
+
if Custom.custom_aot_warning:
|
|
620
|
+
logger.warning("{}, no white list is set and it might cause problems. "
|
|
621
|
+
"Set the legal path of the file in MS_CUSTOM_AOT_WHITE_LIST"
|
|
622
|
+
.format(self.log_prefix))
|
|
623
|
+
Custom.custom_aot_warning = False
|
|
624
|
+
else:
|
|
625
|
+
legal_path = os.path.abspath(os.environ.get('MS_CUSTOM_AOT_WHITE_LIST'))
|
|
626
|
+
if legal_path not in file_path:
|
|
627
|
+
raise TypeError(
|
|
628
|
+
"{}, the legal path for the file is {}, but the file is {}".format(
|
|
629
|
+
self.log_prefix, legal_path, file_path))
|
|
630
|
+
if file_path.endswith(("cu", "cpp", "cc")):
|
|
621
631
|
file_path = _compile_aot(file_path)
|
|
622
632
|
self.func = file_path + ":" + file_name_list[1]
|
|
623
633
|
|
|
@@ -639,7 +649,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
639
649
|
"The kernel will be executed as a native python function, which might lead to "
|
|
640
650
|
"low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
|
|
641
651
|
.format(self.log_prefix))
|
|
642
|
-
|
|
652
|
+
elif self.func_type == "tbe":
|
|
643
653
|
if not callable(self.func):
|
|
644
654
|
raise TypeError("{}, 'func' must be of type function, but got {}"
|
|
645
655
|
.format(self.log_prefix, type(self.func)))
|
|
@@ -661,10 +671,10 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
661
671
|
if file_path not in Custom.tbe_path_failed:
|
|
662
672
|
# As a single file might include multiply functions
|
|
663
673
|
# we will not try the file path which already failed in previous trials
|
|
664
|
-
mod_spec = importlib.util.spec_from_file_location(
|
|
665
|
-
self.func_name, file_path)
|
|
666
|
-
custom_mod = importlib.util.module_from_spec(mod_spec)
|
|
667
674
|
try:
|
|
675
|
+
mod_spec = importlib.util.spec_from_file_location(
|
|
676
|
+
self.func_name, file_path)
|
|
677
|
+
custom_mod = importlib.util.module_from_spec(mod_spec)
|
|
668
678
|
mod_spec.loader.exec_module(custom_mod)
|
|
669
679
|
except (ImportError, RecursionError):
|
|
670
680
|
Custom.tbe_path_failed.append(file_path)
|
|
@@ -756,16 +766,21 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
756
766
|
|
|
757
767
|
def _update_reg_attrs(self, reg_info):
|
|
758
768
|
"""Update op attrs in reg_info."""
|
|
769
|
+
output_name_list = []
|
|
759
770
|
for _, item in enumerate(reg_info.get("outputs", [])):
|
|
760
|
-
output_name_list = []
|
|
761
771
|
if isinstance(item, dict) and item.get("name"):
|
|
762
772
|
output_name_list.append(item.get("name"))
|
|
773
|
+
if output_name_list:
|
|
763
774
|
self.add_prim_attr("output_names", output_name_list)
|
|
764
775
|
|
|
765
776
|
if isinstance(reg_info.get("op_name"), str):
|
|
766
777
|
self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
|
|
767
778
|
|
|
768
|
-
if self.func_type == "
|
|
779
|
+
if self.func_type == "aicpu":
|
|
780
|
+
self.uniq_name = reg_info["op_name"]
|
|
781
|
+
self.add_prim_attr("uniq_name", self.uniq_name)
|
|
782
|
+
|
|
783
|
+
if self.func_type in ["aot", "aicpu"]:
|
|
769
784
|
if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
|
|
770
785
|
for item in reg_info["attr"]:
|
|
771
786
|
if isinstance(item, dict) and item.get("value") is not None:
|
|
@@ -852,12 +867,6 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
852
867
|
else:
|
|
853
868
|
Custom.registered_func[func_name] = [target]
|
|
854
869
|
|
|
855
|
-
def _get_op_name(self, reg_info):
|
|
856
|
-
if self.func_type == "aicpu":
|
|
857
|
-
self.uniq_name = reg_info["op_name"]
|
|
858
|
-
self.add_prim_attr("uniq_name", self.uniq_name)
|
|
859
|
-
return self.uniq_name
|
|
860
|
-
|
|
861
870
|
def _reformat_reg_info(self, reg_info, target):
|
|
862
871
|
"""Reformat registration information."""
|
|
863
872
|
if not isinstance(reg_info, dict):
|
|
@@ -865,7 +874,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
865
874
|
"'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
|
|
866
875
|
"use 'custom_info_register' to bind it to 'func' if 'func' is a function."
|
|
867
876
|
.format(self.log_prefix, reg_info, type(reg_info)))
|
|
868
|
-
reg_info["op_name"] = self.
|
|
877
|
+
reg_info["op_name"] = self.uniq_name
|
|
869
878
|
reg_info["imply_type"] = self._get_imply_type(reg_info, target)
|
|
870
879
|
if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
|
|
871
880
|
reg_info["fusion_type"] = "OPAQUE"
|
|
@@ -926,26 +935,37 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
926
935
|
"""Save input_names and attr_names of current func."""
|
|
927
936
|
if not isinstance(reg_info, dict):
|
|
928
937
|
return
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
+
|
|
939
|
+
def _get_value_list(key):
|
|
940
|
+
value = reg_info.get(key, [])
|
|
941
|
+
if not isinstance(value, (list, tuple)):
|
|
942
|
+
value = [value]
|
|
943
|
+
return value
|
|
944
|
+
|
|
945
|
+
tensor_inputs = _get_value_list("inputs")
|
|
946
|
+
attr = _get_value_list("attr")
|
|
947
|
+
input_names = [] # include tensor input names and attr input names
|
|
938
948
|
attr_names = []
|
|
949
|
+
pure_input_names = []
|
|
939
950
|
for item in tensor_inputs:
|
|
940
951
|
if isinstance(item, dict) and item.get("name") is not None:
|
|
941
952
|
input_names.append(item["name"])
|
|
942
|
-
|
|
953
|
+
pure_input_names.append(item["name"])
|
|
954
|
+
# attr is converted from inputs only when graph mode or when inputs name is also in reg info
|
|
955
|
+
attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
|
|
943
956
|
for item in attr:
|
|
944
957
|
if isinstance(item, dict) and item.get("name") is not None:
|
|
945
|
-
|
|
958
|
+
# for custom op with function tbe, we always add attrs to inputs as we don't
|
|
959
|
+
# deal with attr value here and leave them to the backend process to fit the
|
|
960
|
+
# usual process of tbe op compiling in mindspore
|
|
961
|
+
# for the rest cases, namely aot and aicpu, if we find values for attrs, we
|
|
962
|
+
# have already add them as prim attr of the op in the fun _update_reg_attrs
|
|
963
|
+
# add attr name to input name only when the value of attr is None in reg info
|
|
964
|
+
# as we need to get values of attrs from inputs
|
|
965
|
+
if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
|
|
946
966
|
input_names.append(item["name"])
|
|
947
967
|
attr_names.append(item["name"])
|
|
948
|
-
cur_attr = {"input_names": input_names, "attr_names": attr_names}
|
|
968
|
+
cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
|
|
949
969
|
# If func does not have attr, save current attr.
|
|
950
970
|
# Else, check if current attr is same as previous saved one.
|
|
951
971
|
prev_attr_names = attr_names
|
|
@@ -994,7 +1014,12 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
994
1014
|
|
|
995
1015
|
def _update_attr(self):
|
|
996
1016
|
"""Add input_names, attr_names, primitive_target to primitive's attr."""
|
|
997
|
-
|
|
1017
|
+
|
|
1018
|
+
def _add_prim_attr(key):
|
|
1019
|
+
value = func_attr.get(key)
|
|
1020
|
+
if value:
|
|
1021
|
+
self.add_prim_attr(key, value)
|
|
1022
|
+
|
|
998
1023
|
func_attr = {}
|
|
999
1024
|
if callable(self.func):
|
|
1000
1025
|
inputs_num = len(inspect.signature(self.func).parameters)
|
|
@@ -1003,12 +1028,9 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
1003
1028
|
elif isinstance(self.func, str):
|
|
1004
1029
|
func_attr = Custom.attr_dict.get(self.func)
|
|
1005
1030
|
if isinstance(func_attr, dict):
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
self.add_prim_attr("input_names", input_names)
|
|
1010
|
-
if attr_names:
|
|
1011
|
-
self.add_prim_attr("attr_names", attr_names)
|
|
1031
|
+
_add_prim_attr("input_names")
|
|
1032
|
+
_add_prim_attr("attr_names")
|
|
1033
|
+
_add_prim_attr("pure_input_names")
|
|
1012
1034
|
self._add_prim_target()
|
|
1013
1035
|
if callable(self.func) and callable(self.out_shape):
|
|
1014
1036
|
if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
|
|
@@ -1065,7 +1087,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
1065
1087
|
arg_dtype = arg["dtype"]
|
|
1066
1088
|
# if any value is missing from input, disable infer value
|
|
1067
1089
|
enable_infer_value = False
|
|
1068
|
-
if isinstance(arg_dtype, mstype.
|
|
1090
|
+
if isinstance(arg_dtype, mstype.TensorType):
|
|
1069
1091
|
arg_dtype = arg_dtype.element_type()
|
|
1070
1092
|
fake_arg = np.zeros(arg["shape"]).astype(
|
|
1071
1093
|
mstype.dtype_to_nptype(arg_dtype))
|
|
@@ -1075,7 +1097,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|
|
1075
1097
|
|
|
1076
1098
|
if hasattr(fake_output, 'shape'):
|
|
1077
1099
|
infer_shape = fake_output.shape
|
|
1078
|
-
infer_dtype = mstype.
|
|
1100
|
+
infer_dtype = mstype.TensorType(mstype.pytype_to_dtype(fake_output.dtype))
|
|
1079
1101
|
else:
|
|
1080
1102
|
infer_shape = (1,)
|
|
1081
1103
|
infer_dtype = mstype.pytype_to_dtype(fake_output.dtype)
|
|
@@ -43,7 +43,7 @@ def _check_summary_param(name, value, class_name):
|
|
|
43
43
|
raise ValueError(f"For '{class_name}', the name must be valid string, but got '{n_value}'.")
|
|
44
44
|
|
|
45
45
|
v_type = value['dtype']
|
|
46
|
-
validator.check_value_type('value', v_type, [type(mstype.
|
|
46
|
+
validator.check_value_type('value', v_type, [type(mstype.tensor_type)], class_name)
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
# Note: The return value of the summary operator is not used,
|
|
@@ -58,7 +58,7 @@ class ScalarSummary(Primitive):
|
|
|
58
58
|
This operator will put a scalar to a summary file with protocol buffer format. It must be used with SummaryRecord
|
|
59
59
|
or SummaryCollector, which specify the directory of the summary file. The summary file can
|
|
60
60
|
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
61
|
-
mindinsight/docs/en/r2.
|
|
61
|
+
mindinsight/docs/en/r2.2/index.html>`_ for details.
|
|
62
62
|
|
|
63
63
|
Inputs:
|
|
64
64
|
- **name** (str) - The name of the input variable, it must not be an empty string.
|
|
@@ -104,6 +104,7 @@ class ScalarSummary(Primitive):
|
|
|
104
104
|
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
|
105
105
|
|
|
106
106
|
self.add_prim_attr("side_effect_io", True)
|
|
107
|
+
self.add_prim_attr("channel_name", "ms_scalar_summary")
|
|
107
108
|
|
|
108
109
|
def __call__(self, *args):
|
|
109
110
|
_cache_summary_data(self.name, args[0], args[1])
|
|
@@ -114,7 +115,7 @@ class ImageSummary(PrimitiveWithInfer):
|
|
|
114
115
|
This operator will put an image tensor to a summary file with protocol buffer format. It must be used with
|
|
115
116
|
SummaryRecord or SummaryCollector, which specify the directory of the summary file. The summary file can
|
|
116
117
|
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
117
|
-
mindinsight/docs/en/r2.
|
|
118
|
+
mindinsight/docs/en/r2.2/index.html>`_ for details.
|
|
118
119
|
|
|
119
120
|
Inputs:
|
|
120
121
|
- **name** (str) - The name of the input variable, it must not be an empty string.
|
|
@@ -153,6 +154,7 @@ class ImageSummary(PrimitiveWithInfer):
|
|
|
153
154
|
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
|
154
155
|
|
|
155
156
|
self.add_prim_attr("side_effect_io", True)
|
|
157
|
+
self.add_prim_attr("channel_name", "ms_image_summary")
|
|
156
158
|
|
|
157
159
|
def __infer__(self, name, value):
|
|
158
160
|
_check_summary_param(name, value, self.__class__.__name__)
|
|
@@ -175,7 +177,7 @@ class TensorSummary(Primitive):
|
|
|
175
177
|
This operator will put a tensor to a summary file with protocol buffer format. It must be used with SummaryRecord
|
|
176
178
|
or SummaryCollector, which specify the directory of the summary file. The summary file can
|
|
177
179
|
be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
178
|
-
mindinsight/docs/en/r2.
|
|
180
|
+
mindinsight/docs/en/r2.2/index.html>`_ for details.
|
|
179
181
|
|
|
180
182
|
Inputs:
|
|
181
183
|
- **name** (str) - The name of the input variable.
|
|
@@ -221,6 +223,7 @@ class TensorSummary(Primitive):
|
|
|
221
223
|
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
|
222
224
|
|
|
223
225
|
self.add_prim_attr("side_effect_io", True)
|
|
226
|
+
self.add_prim_attr("channel_name", "ms_tensor_summary")
|
|
224
227
|
|
|
225
228
|
def __call__(self, *args):
|
|
226
229
|
_cache_summary_data(self.name, args[0], args[1])
|
|
@@ -231,7 +234,7 @@ class HistogramSummary(PrimitiveWithInfer):
|
|
|
231
234
|
This operator will calculate the histogram of a tensor and put it to a summary file with protocol buffer format.
|
|
232
235
|
It must be used with SummaryRecord or SummaryCollector, which specify the directory of the summary file.
|
|
233
236
|
The summary file can be loaded and shown by MindInsight, see `MindInsight documents <https://www.mindspore.cn/
|
|
234
|
-
mindinsight/docs/en/r2.
|
|
237
|
+
mindinsight/docs/en/r2.2/index.html>`_ for details.
|
|
235
238
|
|
|
236
239
|
Inputs:
|
|
237
240
|
- **name** (str) - The name of the input variable.
|
|
@@ -276,6 +279,7 @@ class HistogramSummary(PrimitiveWithInfer):
|
|
|
276
279
|
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
|
277
280
|
|
|
278
281
|
self.add_prim_attr("side_effect_io", True)
|
|
282
|
+
self.add_prim_attr("channel_name", "ms_histogram_summary")
|
|
279
283
|
|
|
280
284
|
def __infer__(self, name, value):
|
|
281
285
|
_check_summary_param(name, value, self.__class__.__name__)
|
|
@@ -380,7 +384,7 @@ class HookBackward(PrimitiveWithInfer):
|
|
|
380
384
|
hook_fn (Function): Python function. hook function.
|
|
381
385
|
cell_id (str, optional): Used to identify whether the function registered by the hook is actually registered on
|
|
382
386
|
the specified cell object. For example, 'nn.Conv2d' is a cell object.
|
|
383
|
-
|
|
387
|
+
Default: ``""``, in this case, the system will automatically
|
|
384
388
|
register a value of `cell_id`.
|
|
385
389
|
The value of `cell_id` currently does not support custom values.
|
|
386
390
|
|
|
@@ -444,7 +448,7 @@ class HookBackward(PrimitiveWithInfer):
|
|
|
444
448
|
|
|
445
449
|
def infer_dtype(self, *inputs_type):
|
|
446
450
|
for dtype in inputs_type:
|
|
447
|
-
validator.check_subclass("input", dtype, [mstype.
|
|
451
|
+
validator.check_subclass("input", dtype, [mstype.tensor_type], self.name)
|
|
448
452
|
if len(inputs_type) == 1:
|
|
449
453
|
return inputs_type[0]
|
|
450
454
|
return inputs_type
|
|
@@ -456,10 +460,19 @@ class Print(Primitive):
|
|
|
456
460
|
|
|
457
461
|
Refer to :func:`mindspore.ops.print_` for more detail.
|
|
458
462
|
|
|
463
|
+
Inputs:
|
|
464
|
+
- **input_x** (Union[Tensor, bool, int, float, str]) - The graph node to attach to.
|
|
465
|
+
Supports multiple inputs which are separated by ','.
|
|
466
|
+
|
|
467
|
+
Outputs:
|
|
468
|
+
Tensor, has the same data type and shape as original `input_x`.
|
|
469
|
+
|
|
459
470
|
Supported Platforms:
|
|
460
471
|
``Ascend`` ``GPU`` ``CPU``
|
|
461
472
|
|
|
462
473
|
Examples:
|
|
474
|
+
>>> import numpy as np
|
|
475
|
+
>>> from mindspore import Tensor, nn
|
|
463
476
|
>>> class PrintDemo(nn.Cell):
|
|
464
477
|
... def __init__(self):
|
|
465
478
|
... super(PrintDemo, self).__init__()
|
|
@@ -503,16 +516,16 @@ class Print(Primitive):
|
|
|
503
516
|
class Assert(PrimitiveWithInfer):
|
|
504
517
|
"""
|
|
505
518
|
Asserts whether the given condition is True.
|
|
506
|
-
If input condition is identified to be
|
|
519
|
+
If input condition is identified to be ``False``, print a list of the tensor in data.
|
|
507
520
|
|
|
508
521
|
Args:
|
|
509
522
|
summarize (int, optional): The number of entries to be printed in each tensor while the given condition is
|
|
510
|
-
identified to be False. Default: 3.
|
|
523
|
+
identified to be ``False`` . Default: ``3`` .
|
|
511
524
|
|
|
512
525
|
Inputs:
|
|
513
526
|
- **condition** (Union[Tensor[bool], bool]) - The condition to be identified.
|
|
514
527
|
- **input_data** (Union[tuple[Tensor], list[Tensor]]) - The tensors to be printed out when the condition
|
|
515
|
-
is
|
|
528
|
+
is ``False``.
|
|
516
529
|
|
|
517
530
|
Raises:
|
|
518
531
|
TypeError: If `summarize` is not an int.
|
|
@@ -560,5 +573,5 @@ class Assert(PrimitiveWithInfer):
|
|
|
560
573
|
def infer_dtype(self, condition, inputs):
|
|
561
574
|
validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
|
|
562
575
|
for dtype in inputs:
|
|
563
|
-
validator.check_subclass("input", dtype, [mstype.
|
|
576
|
+
validator.check_subclass("input", dtype, [mstype.tensor_type], self.name)
|
|
564
577
|
return mstype.int32
|