mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -21,15 +21,222 @@ import inspect
|
|
|
21
21
|
import json
|
|
22
22
|
import os
|
|
23
23
|
import functools
|
|
24
|
+
import platform
|
|
25
|
+
import hashlib
|
|
26
|
+
import shutil
|
|
24
27
|
|
|
25
28
|
from mindspore._c_expression import Oplib
|
|
26
29
|
from mindspore import _checkparam as validator
|
|
30
|
+
from mindspore import log as logger
|
|
31
|
+
|
|
32
|
+
if platform.system() == "Linux":
|
|
33
|
+
import fcntl
|
|
27
34
|
|
|
28
35
|
# path of built-in op info register.
|
|
29
36
|
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
|
|
30
37
|
BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
|
|
31
38
|
|
|
32
39
|
|
|
40
|
+
def _get_reg_info_attr(op_info, attr_name):
|
|
41
|
+
"""get attr value"""
|
|
42
|
+
for _, item in enumerate(op_info.get("attr", [])):
|
|
43
|
+
if item.get("name") == attr_name:
|
|
44
|
+
return item.get("defaultValue")
|
|
45
|
+
return None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class _CustomInstaller:
    """
    Save custom op registration information to json files which will be used by GE.

    For an op whose reg info carries a "cust_aicpu" attribute, the aicpu so file is copied and a
    "cust_aicpu_kernel.json" is written; otherwise the python file that defines `func` is copied and an
    "aic-<arch>-ops-info.json" is written for each supported architecture. All files are placed under
    "<cwd>/vendors/ms", which is also prepended to the "ASCEND_CUSTOM_OPP_PATH" environment variable.

    Args:
        op_info (dict): Registration information of the custom op.
        func (function): Implementation function of the op. When given, its ``__name__`` is used as the
            op type, otherwise ``op_info["op_name"]`` is used. Default: ``None``.
    """
    reg_info_hash = []  # used to avoid writing the same reg info to file multiple times
    copied_paths = []  # used to avoid copying the same file multiple times

    def __init__(self, op_info, func=None):
        self.op_info = op_info
        self.func = func
        self.op_type = op_info.get("op_name") if not func else func.__name__
        vendor_name = "ms"
        custom_dir = os.path.join(os.path.realpath("./"), "vendors", vendor_name)
        self._set_env(custom_dir)
        op_impl_dir = os.path.join(custom_dir, "op_impl")
        self.ai_core_config_dir = os.path.join(op_impl_dir, "ai_core", "tbe", "config")
        self.ai_core_impl_dir = os.path.join(op_impl_dir, "ai_core", "tbe", vendor_name + "_impl")
        self.ai_cpu_config_dir = os.path.join(op_impl_dir, "cpu", "config")
        self.ai_cpu_impl_dir = os.path.join(op_impl_dir, "cpu", "aicpu_kernel", "impl")

    @staticmethod
    def _set_env(custom_opp_path):
        """Prepend `custom_opp_path` to the env "ASCEND_CUSTOM_OPP_PATH" unless it is already listed there."""
        current = os.environ.get("ASCEND_CUSTOM_OPP_PATH")
        if not current:
            os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path
        elif custom_opp_path not in current.split(':'):
            os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path + ':' + current

    @staticmethod
    def _create_dir(*dir_names):
        """Create every directory in `dir_names` (including parents) that does not exist yet."""
        for dir_name in dir_names:
            if os.path.isdir(dir_name):
                continue
            try:
                os.makedirs(dir_name, exist_ok=True)
            except FileExistsError:
                # Same condition as the former `errno == 17` check: the path already exists.
                # Any other OSError still propagates to the caller.
                pass

    @staticmethod
    def _copy_file(src_path, dst_dir):
        """
        Copy the file `src_path` into `dst_dir` while holding an exclusive lock on "<dst_dir>/file.lock".

        Nothing is done when `src_path` does not exist or was already handled by a previous call.
        """
        if not os.path.exists(src_path) or src_path in _CustomInstaller.copied_paths:
            return
        # The path is recorded before copying, so it is never retried, even if it is not a regular file.
        _CustomInstaller.copied_paths.append(src_path)
        if os.path.isfile(src_path):
            lock_file = os.path.join(dst_dir, "file.lock")
            with open(lock_file, "w") as f:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                shutil.copy(src_path, dst_dir)

    def _check(self):
        """
        Check whether the reg info needs to be written.

        Returns:
            bool, ``False`` when the platform is not Linux, when the env "MS_DEV_CUSTOM_OPP_PATH" is not set,
            when the op target is "GPU" or "CPU", or when an identical reg info has already been processed.
            Otherwise the hash of the reg info is recorded and ``True`` is returned.
        """
        if platform.system() != "Linux":
            return False
        if not os.environ.get("MS_DEV_CUSTOM_OPP_PATH"):
            # only process the first time import the mindspore module
            return False
        if self.op_info.get("target") in ["GPU", "CPU"]:
            return False
        # sort_keys makes the digest independent of the key order of the reg info dict.
        value = json.dumps(self.op_info, sort_keys=True).encode()
        hash_value = hashlib.sha256(value).hexdigest()
        if hash_value in _CustomInstaller.reg_info_hash:
            return False
        _CustomInstaller.reg_info_hash.append(hash_value)
        return True

    def _find_ai_cpu_so_path(self, so_file):
        """
        Find the path of the aicpu so file.

        Returns:
            str, path of the first existing "<search dir>/<so_file>", or "" (after logging a warning) when
            the file is found in none of the search directories.
        """
        current_path = os.path.dirname(os.path.abspath(__file__))
        search_paths = [current_path + "/../lib", current_path + "/../lib/plugin/ascend"]
        for path in search_paths:
            so_path = os.path.join(path, so_file)
            if os.path.exists(so_path):
                return so_path
        logger.warning("For Custom op '{}', can not find the aicpu so file '{}' in the following directories:\n{}"
                       .format(self.op_type, so_file, "\n".join(search_paths)))
        return ""

    def _gen_ai_core_reg_info(self, imply_path, func_name):
        """
        Generate the ai core reg info from `self.op_info`.

        Args:
            imply_path (str): Path of the python file implementing the op. Its base name without extension
                is stored as "opFile".
            func_name (str): Name of the implementation function, stored as "opInterface".

        Returns:
            dict, the reg info with "opFile", "opInterface", one "attr_<name>" entry per attribute (plus an
            "attr" entry listing all attribute names) and one "input<index>"/"output<index>" entry per
            input/output.
        """

        def _get_dtype_format(idx):
            """
            Collect the dtypes and formats that every "dtype_format" entry declares for position `idx`.

            A result is ``None`` as soon as one entry leaves the corresponding value empty. It then stays
            ``None``, so a later non-empty value no longer fails with AttributeError on ``None.append``.
            """
            data_type = []
            data_format = []
            for dtype_format in self.op_info.get("dtype_format", []):
                dtype = dtype_format[idx][0]
                fmt = dtype_format[idx][1]
                if not dtype:
                    data_type = None
                elif data_type is not None:
                    data_type.append(dtype)
                if not fmt:
                    data_format = None
                elif data_format is not None:
                    data_format.append("ND" if fmt == "DefaultFormat" else fmt)
            return data_type, data_format

        op_info = {"opFile": {"value": os.path.splitext(os.path.basename(imply_path))[0]},
                   "opInterface": {"value": func_name}}
        # attr
        attrs_name = []
        for item in self.op_info.get("attr", []):
            attr_name = item.get("name")
            attrs_name.append(attr_name)
            op_info["attr_" + attr_name] = {k: v for k, v in item.items() if k != "name"}
        if attrs_name:
            op_info["attr"] = {"list": ",".join(attrs_name)}
        # input and output
        inputs = self.op_info.get("inputs", [])
        outputs = self.op_info.get("outputs", [])
        input_num = len(inputs)
        # A "dtype_format" entry is indexed by position in the concatenation of inputs and outputs.
        for i, item in enumerate([*inputs, *outputs]):
            key = ("input" if i < input_num else "output") + str(item.get("index"))
            op_info[key] = {"name": item.get("name"),
                            "paramType": item.get("paramType", "required"),
                            "shape": item.get("shape", "all")}
            dtype, formats = _get_dtype_format(i)
            if dtype:
                op_info[key]["dtype"] = ",".join(dtype)
            if formats:
                op_info[key]["format"] = ",".join(formats)
        return op_info

    def _gen_ai_cpu_reg_info(self, so_file):
        """Generate the ai cpu reg info, a fixed "opInfo" dict whose "kernelSo" is `so_file`."""
        op_info = {"opInfo": {"computeCost": "100",
                              "engine": "DNN_VM_AICPU",
                              "flagAsync": "False",
                              "flagPartial": "False",
                              "functionName": "RunCpuKernel",
                              "kernelSo": so_file,
                              "opKernelLib": "CUSTAICPUKernel",
                              "userDefined": "True"}}
        return op_info

    def _save_op_info(self, dst_dir, file_name, op_info):
        """
        Merge `op_info` under the key `self.op_type` into the json file "<dst_dir>/<file_name>".

        The read-modify-write is done while holding an exclusive lock on "<dst_dir>/file.lock". An existing
        file is loaded first (an empty file counts as an empty dict), and the result is written with mode 0o600.
        """
        repo = {}
        save_path = os.path.join(dst_dir, file_name)
        lock_file = os.path.join(dst_dir, "file.lock")
        with open(lock_file, "w") as f:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
            if os.path.isfile(save_path):
                with open(save_path, 'r') as fr:
                    json_str = fr.read()
                json_str = "{}" if json_str == "" else json_str
                repo = json.loads(json_str)
            repo[self.op_type] = op_info
            with os.fdopen(os.open(save_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fw:
                json.dump(repo, fw, sort_keys=True, indent=4, separators=(',', ':'))

    def run(self):
        """Save the reg info to file and copy the implementation file, if `_check` allows it."""
        if not self._check():
            return
        so_name = _get_reg_info_attr(self.op_info, "cust_aicpu")
        if so_name:
            _CustomInstaller._create_dir(self.ai_cpu_config_dir, self.ai_cpu_impl_dir)
            # copy so file
            so_file = "lib" + so_name + ".so"
            imply_path = self._find_ai_cpu_so_path(so_file)
            self._copy_file(imply_path, self.ai_cpu_impl_dir)
            # generate and copy reg info file
            op_info = self._gen_ai_cpu_reg_info(so_file)
            self._save_op_info(self.ai_cpu_config_dir, "cust_aicpu_kernel.json", op_info)
        else:
            _CustomInstaller._create_dir(self.ai_core_config_dir, self.ai_core_impl_dir)
            # copy dsl file (a single call is enough: a repeated call on the same path is a no-op)
            imply_path = os.path.realpath(inspect.getfile(self.func))
            self._copy_file(imply_path, self.ai_core_impl_dir)
            # generate and copy reg info file
            op_info = self._gen_ai_core_reg_info(imply_path, self.func.__name__)
            for arc_name in ["ascend910", "ascend910b"]:
                arc_dir = os.path.join(self.ai_core_config_dir, arc_name)
                _CustomInstaller._create_dir(arc_dir)
                self._save_op_info(arc_dir, "aic-{}-ops-info.json".format(arc_name), op_info)
|
|
238
|
+
|
|
239
|
+
|
|
33
240
|
def op_info_register(op_info):
|
|
34
241
|
r"""
|
|
35
242
|
A decorator which is used to register an operator.
|
|
@@ -125,6 +332,12 @@ def custom_info_register(*reg_info):
|
|
|
125
332
|
|
|
126
333
|
def decorator(func):
|
|
127
334
|
setattr(func, "reg_info", reg_info)
|
|
335
|
+
if reg_info:
|
|
336
|
+
used_reg_info = reg_info[0]
|
|
337
|
+
if isinstance(used_reg_info, dict):
|
|
338
|
+
# ai_cpu should be parsed inside CustomRegOp, skip it here
|
|
339
|
+
if not _get_reg_info_attr(used_reg_info, "cust_aicpu"):
|
|
340
|
+
_CustomInstaller(used_reg_info, func).run()
|
|
128
341
|
|
|
129
342
|
@functools.wraps(func)
|
|
130
343
|
def wrapper(*args, **kwargs):
|
|
@@ -140,7 +353,7 @@ class RegOp:
|
|
|
140
353
|
Base class for op info register.
|
|
141
354
|
|
|
142
355
|
Args:
|
|
143
|
-
op_name (str): Name of
|
|
356
|
+
op_name (str): Name of operator.
|
|
144
357
|
"""
|
|
145
358
|
|
|
146
359
|
def __init__(self, op_name=""):
|
|
@@ -446,10 +659,10 @@ class AkgCpuRegOp(AkgRegOp):
|
|
|
446
659
|
|
|
447
660
|
class AiCPURegOp(CpuRegOp):
|
|
448
661
|
r"""
|
|
449
|
-
Class for AiCPU operator information
|
|
662
|
+
Class for AiCPU operator information registration.
|
|
450
663
|
|
|
451
664
|
Args:
|
|
452
|
-
op_name (str):
|
|
665
|
+
op_name (str): Name of operator.
|
|
453
666
|
|
|
454
667
|
Examples:
|
|
455
668
|
>>> from mindspore.ops import AiCPURegOp, DataType
|
|
@@ -481,14 +694,15 @@ class AiCPURegOp(CpuRegOp):
|
|
|
481
694
|
|
|
482
695
|
class TBERegOp(RegOp):
|
|
483
696
|
r"""
|
|
484
|
-
Class for TBE operator information
|
|
697
|
+
Class for TBE operator information registration. TBE (Tensor Boost Engine) is the Ascend operator development
|
|
698
|
+
tool, which is extended on the basis of the TVM framework to develop custom operators.
|
|
485
699
|
|
|
486
700
|
Args:
|
|
487
|
-
op_name (str):
|
|
701
|
+
op_name (str): Name of operator.
|
|
488
702
|
|
|
489
703
|
Examples:
|
|
490
|
-
>>>
|
|
491
|
-
>>> op_name_op_info =
|
|
704
|
+
>>> from mindspore.ops import TBERegOp, DataType
|
|
705
|
+
>>> op_name_op_info = TBERegOp("OpName") \
|
|
492
706
|
... .fusion_type("ELEMWISE") \
|
|
493
707
|
... .async_flag(False) \
|
|
494
708
|
... .binfile_name("op_name.so") \
|
|
@@ -505,14 +719,14 @@ class TBERegOp(RegOp):
|
|
|
505
719
|
... .input(0, "x2", None, "required", None) \
|
|
506
720
|
... .input(1, "axis", None, "required", None) \
|
|
507
721
|
... .output(0, "y", True, "required", "all") \
|
|
508
|
-
... .real_input_index([1, 0])
|
|
509
|
-
... .input_to_attr_index([2])
|
|
510
|
-
... .unknown_shape_formats(["ND", "ND", "ND", "ND"])
|
|
722
|
+
... .real_input_index([1, 0]) \
|
|
723
|
+
... .input_to_attr_index([2]) \
|
|
724
|
+
... .unknown_shape_formats(["ND", "ND", "ND", "ND"]) \
|
|
511
725
|
... .reshape_type("NC") \
|
|
512
726
|
... .is_dynamic_format(True) \
|
|
513
|
-
... .dtype_format(DataType.F16_None, DataType.F16_None) \
|
|
514
|
-
... .dtype_format(DataType.F32_None, DataType.F32_None) \
|
|
515
|
-
... .dtype_format(DataType.I32_None, DataType.I32_None) \
|
|
727
|
+
... .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
|
728
|
+
... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
|
729
|
+
... .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \
|
|
516
730
|
... .get_op_info()
|
|
517
731
|
>>>
|
|
518
732
|
"""
|
|
@@ -830,7 +1044,7 @@ class CustomRegOp(RegOp):
|
|
|
830
1044
|
|
|
831
1045
|
Tutorial Examples:
|
|
832
1046
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
833
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1047
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
834
1048
|
defining-custom-operator-of-aicpu-type>`_
|
|
835
1049
|
"""
|
|
836
1050
|
param_list = [index, name, param_type]
|
|
@@ -870,7 +1084,7 @@ class CustomRegOp(RegOp):
|
|
|
870
1084
|
|
|
871
1085
|
Tutorial Examples:
|
|
872
1086
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
873
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1087
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
874
1088
|
defining-custom-operator-of-aicpu-type>`_
|
|
875
1089
|
"""
|
|
876
1090
|
param_list = [index, name, param_type]
|
|
@@ -898,7 +1112,7 @@ class CustomRegOp(RegOp):
|
|
|
898
1112
|
|
|
899
1113
|
Tutorial Examples:
|
|
900
1114
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
901
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1115
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
902
1116
|
defining-custom-operator-of-aicpu-type>`_
|
|
903
1117
|
"""
|
|
904
1118
|
io_nums = len(self.inputs) + len(self.outputs)
|
|
@@ -955,7 +1169,7 @@ class CustomRegOp(RegOp):
|
|
|
955
1169
|
|
|
956
1170
|
Tutorial Examples:
|
|
957
1171
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
958
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1172
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
959
1173
|
defining-custom-operator-of-aicpu-type>`_
|
|
960
1174
|
"""
|
|
961
1175
|
param_list = [name, param_type, value_type, default_value]
|
|
@@ -981,7 +1195,7 @@ class CustomRegOp(RegOp):
|
|
|
981
1195
|
|
|
982
1196
|
Tutorial Examples:
|
|
983
1197
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
984
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1198
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
985
1199
|
defining-custom-operator-of-aicpu-type>`_
|
|
986
1200
|
"""
|
|
987
1201
|
if target is not None:
|
|
@@ -996,7 +1210,7 @@ class CustomRegOp(RegOp):
|
|
|
996
1210
|
|
|
997
1211
|
Tutorial Examples:
|
|
998
1212
|
- `Custom Operators (Custom-based) - Defining Custom Operator of aicpu Type
|
|
999
|
-
<https://mindspore.cn/tutorials/experts/en/r2.
|
|
1213
|
+
<https://mindspore.cn/tutorials/experts/en/r2.2/operation/op_custom.html#
|
|
1000
1214
|
defining-custom-operator-of-aicpu-type>`_
|
|
1001
1215
|
"""
|
|
1002
1216
|
op_info = {}
|
|
@@ -1004,6 +1218,8 @@ class CustomRegOp(RegOp):
|
|
|
1004
1218
|
if isinstance(k, str) and k.endswith('_'):
|
|
1005
1219
|
k = k.rstrip('_')
|
|
1006
1220
|
op_info[k] = v
|
|
1221
|
+
if _get_reg_info_attr(op_info, "cust_aicpu"):
|
|
1222
|
+
_CustomInstaller(op_info).run()
|
|
1007
1223
|
return op_info
|
|
1008
1224
|
|
|
1009
1225
|
|
|
@@ -21,7 +21,7 @@ A collection of operators to build neural networks or to compute functions.
|
|
|
21
21
|
|
|
22
22
|
from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
|
|
23
23
|
MapUniform, DynamicAssign, PadAndShift)
|
|
24
|
-
from ._inner_ops import (MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches)
|
|
24
|
+
from ._inner_ops import (MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches, SelectView, CopyWithSlice)
|
|
25
25
|
from ._quant_ops import *
|
|
26
26
|
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
|
27
27
|
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
|
@@ -29,20 +29,20 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
|
|
|
29
29
|
LoadIm2Col, UpdateThorGradient, CholeskyTrsm,
|
|
30
30
|
DetTriangle, ProdForceSeA)
|
|
31
31
|
from ._ms_kernel import (ms_kernel, kernel)
|
|
32
|
-
from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchToSpace,
|
|
32
|
+
from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchToSpace,
|
|
33
33
|
BatchToSpaceNDV2, BroadcastTo, Cast, Coalesce, Concat, Cummax, DType, DepthToSpace, Diag,
|
|
34
|
-
DiagPart,
|
|
35
|
-
Eye, Fill, Gather, GatherD, GatherNd,
|
|
34
|
+
DiagPart, EditDistance, EmbeddingLookup, ExpandDims, ExtractVolumePatches,
|
|
35
|
+
Eye, Fill, Gather, GatherD, GatherNd, Identity, Im2Col, InvertPermutation,
|
|
36
36
|
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
|
|
37
|
-
|
|
38
|
-
ReverseSequence, ReverseV2, Rint,
|
|
37
|
+
Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
|
|
38
|
+
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
|
|
39
39
|
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
|
|
40
|
-
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate,
|
|
40
|
+
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterSub,
|
|
41
41
|
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
|
|
42
42
|
SpaceToDepth, SparseGatherV2, Split, SplitV, Squeeze, Stack, StridedSlice, TensorScatterAdd,
|
|
43
43
|
TensorScatterDiv, TensorScatterMax, TensorScatterMin, TensorScatterMul, TensorScatterSub,
|
|
44
44
|
TensorScatterUpdate, TensorShape, Tile, TopK, TransShape, Transpose, TupleToArray, Unique,
|
|
45
|
-
UniqueWithPad,
|
|
45
|
+
UniqueWithPad, UnsortedSegmentMax, UnsortedSegmentMin, UnsortedSegmentProd,
|
|
46
46
|
UnsortedSegmentSum, Unstack, UpperBound, Zeros, ZerosLike, AffineGrid, Bincount, CheckNumerics,
|
|
47
47
|
HammingWindow, IdentityN, IndexFill, LeftShift, ListDiff, LogSpace, MatrixBandPart,
|
|
48
48
|
MatrixDiagPartV3, MatrixDiagV3, MatrixSetDiagV3, NonZero, Expand, Col2Im, ConjugateTranspose,
|
|
@@ -69,7 +69,7 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
|
|
|
69
69
|
from .linalg_ops import (Svd, Geqrf)
|
|
70
70
|
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
|
|
71
71
|
BitwiseAnd, BitwiseOr, Ger, BitwiseXor, Inv, Invert, ApproximateEqual,
|
|
72
|
-
InplaceAdd, InplaceSub,
|
|
72
|
+
InplaceAdd, InplaceSub, InplaceUpdateV2,
|
|
73
73
|
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
|
|
74
74
|
Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
|
|
75
75
|
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
|
@@ -79,7 +79,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|
|
79
79
|
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
|
80
80
|
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
|
81
81
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv,
|
|
82
|
-
Addcmul, Square, Sub,
|
|
82
|
+
Addcmul, Square, Sub, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps,
|
|
83
83
|
Tan, MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose, LuSolve,
|
|
84
84
|
CholeskyInverse, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselK1, BesselK1e, BesselY0,
|
|
85
85
|
BesselY1, Bucketize, Cauchy, Cholesky, CholeskySolve, Betainc,
|
|
@@ -92,14 +92,14 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
|
|
92
92
|
from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
|
|
93
93
|
ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
|
|
94
94
|
DepthwiseConv2dNative,
|
|
95
|
-
|
|
96
|
-
InstanceNorm,
|
|
97
|
-
GeLU,
|
|
95
|
+
Dropout, Dropout2D, Dropout3D, Flatten,
|
|
96
|
+
InstanceNorm,
|
|
97
|
+
GeLU, FastGeLU, Elu, CeLU,
|
|
98
98
|
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
|
|
99
99
|
LogSoftmax, MaxPool3D, AvgPool3D,
|
|
100
100
|
MaxPool, DataFormatDimMap,
|
|
101
101
|
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
|
|
102
|
-
|
|
102
|
+
MaxPoolWithArgmaxV2, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2,
|
|
103
103
|
HSwish, HSigmoid,
|
|
104
104
|
ResizeBilinear, Sigmoid, SeLU, HShrink, ApplyKerasMomentum,
|
|
105
105
|
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
|
|
@@ -115,13 +115,13 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
|
|
|
115
115
|
ApplyAdamWithAmsgrad, ApplyAdamWithAmsgradV2, AdaptiveAvgPool3D, AdaptiveMaxPool2D,
|
|
116
116
|
AdaptiveMaxPool3D,
|
|
117
117
|
GridSampler3D, MaxPool3DWithArgmax, MaxUnpool2D, NuclearNorm, NthElement, MultilabelMarginLoss,
|
|
118
|
-
Dilation2D, DataFormatVecPermute, DeformableOffsets, FractionalAvgPool,
|
|
118
|
+
Dilation2D, DataFormatVecPermute, DeformableOffsets, Dense, FractionalAvgPool,
|
|
119
119
|
FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize,
|
|
120
120
|
GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3, ChannelShuffle,
|
|
121
121
|
GLU, MaxUnpool3D, Pdist)
|
|
122
122
|
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
|
123
123
|
ConfusionMatrix, UpdateState, Load, StopGradient,
|
|
124
|
-
CheckValid, Partial, Depend,
|
|
124
|
+
CheckValid, Partial, Depend, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
|
|
125
125
|
SampleDistortedBoundingBoxV2)
|
|
126
126
|
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
|
|
127
127
|
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
|
@@ -129,8 +129,13 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamm
|
|
|
129
129
|
ParameterizedTruncatedNormal, RandomPoisson, MultinomialWithReplacement, RandomShuffle,
|
|
130
130
|
RandpermV2)
|
|
131
131
|
from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
|
|
132
|
-
from .sparse_ops import (
|
|
132
|
+
from .sparse_ops import (
|
|
133
|
+
SparseToDense, SparseTensorDenseMatmul, SparseTensorDenseAdd, SparseSlice)
|
|
133
134
|
from .spectral_ops import (BartlettWindow, BlackmanWindow)
|
|
135
|
+
from ..deprecated import (identity, DropoutDoMask, MaxPoolWithArgmax,
|
|
136
|
+
BNTrainingReduce, BNTrainingUpdate, DropoutGenMask, Gelu, FastGelu,
|
|
137
|
+
TensorAdd, InplaceUpdate, ScatterNonAliasingAdd,
|
|
138
|
+
BatchToSpaceND, Unpack, GatherV2, DynamicShape, ScalarToArray, Pack)
|
|
134
139
|
|
|
135
140
|
__all__ = [
|
|
136
141
|
'HSVToRGB',
|
|
@@ -616,7 +621,10 @@ __all__ = [
|
|
|
616
621
|
"CumulativeLogsumexp",
|
|
617
622
|
"DataFormatVecPermute",
|
|
618
623
|
"DeformableOffsets",
|
|
624
|
+
"Dense",
|
|
619
625
|
"ExtractImagePatches",
|
|
626
|
+
"SelectView",
|
|
627
|
+
"CopyWithSlice",
|
|
620
628
|
"FillDiagonal",
|
|
621
629
|
"Fills",
|
|
622
630
|
"Gcd",
|
|
@@ -390,7 +390,7 @@ class Conv2DBackpropFilter(Primitive):
|
|
|
390
390
|
stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
|
|
391
391
|
dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
|
|
392
392
|
group (int): Splits input into groups. Default: 1.
|
|
393
|
-
data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW'
|
|
393
|
+
data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW', \
|
|
394
394
|
default is 'NCHW'.
|
|
395
395
|
|
|
396
396
|
Returns:
|
|
@@ -636,7 +636,7 @@ class EinsumGrad(PrimitiveWithInfer):
|
|
|
636
636
|
|
|
637
637
|
@prim_attr_register
|
|
638
638
|
def __init__(self, equation):
|
|
639
|
-
|
|
639
|
+
pass
|
|
640
640
|
|
|
641
641
|
def infer_shape(self, x_shapes, dout_shape):
|
|
642
642
|
out_shape = ()
|
|
@@ -1521,9 +1521,11 @@ class LSTMGrad(Primitive):
|
|
|
1521
1521
|
"""Computes the data and weight gradients of LSTM."""
|
|
1522
1522
|
|
|
1523
1523
|
@prim_attr_register
|
|
1524
|
-
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
|
1524
|
+
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
|
|
1525
1525
|
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
|
1526
1526
|
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
|
1527
|
+
self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
|
|
1528
|
+
'proj_size', self.name)
|
|
1527
1529
|
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
|
1528
1530
|
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
|
1529
1531
|
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
|
@@ -2573,7 +2575,12 @@ class MultilabelMarginLossGrad(Primitive):
|
|
|
2573
2575
|
Compute the gradients of MultilabelMarginLoss operation.
|
|
2574
2576
|
|
|
2575
2577
|
Args:
|
|
2576
|
-
reduction (str): Apply specific reduction method to the output: 'none', 'mean',
|
|
2578
|
+
reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
|
|
2579
|
+
``'sum'`` . Default: ``'mean'`` .
|
|
2580
|
+
|
|
2581
|
+
- ``'none'``: no reduction will be applied.
|
|
2582
|
+
- ``'mean'``: compute and return the mean of elements in the output.
|
|
2583
|
+
- ``'sum'``: the output elements will be summed.
|
|
2577
2584
|
|
|
2578
2585
|
Inputs:
|
|
2579
2586
|
- **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
|
|
@@ -2595,7 +2602,7 @@ class MultilabelMarginLossGrad(Primitive):
|
|
|
2595
2602
|
TypeError: If dtype of `y_grad` is not the same as `x`.
|
|
2596
2603
|
ValueError: If length of shape of `x` is neither 1 nor 2.
|
|
2597
2604
|
ValueError: If shape of `x` is not the same as `target`.
|
|
2598
|
-
ValueError: If `reduction` is not one of 'none'
|
|
2605
|
+
ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
|
|
2599
2606
|
ValueError: If shape of `y_grad` is not the same as forward output `y`.
|
|
2600
2607
|
|
|
2601
2608
|
Supported Platforms:
|
|
@@ -2862,7 +2869,9 @@ class Dilation2DBackpropFilter(Primitive):
|
|
|
2862
2869
|
self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
|
|
2863
2870
|
self.add_prim_attr("pad_mode", self.pad_mode.upper())
|
|
2864
2871
|
self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
|
|
2865
|
-
|
|
2872
|
+
def is_in_range(x):
|
|
2873
|
+
return 1 <= x <= 255
|
|
2874
|
+
if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
|
|
2866
2875
|
raise ValueError(f"For '{self.name}', size of stride is not supported, "
|
|
2867
2876
|
f'stride should be in the range of [1, 255], '
|
|
2868
2877
|
f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
|
|
@@ -2917,7 +2926,12 @@ class MultiMarginLossGrad(Primitive):
|
|
|
2917
2926
|
Args:
|
|
2918
2927
|
p (int): Optional. The norm degree for pairwise distance. Should be 1 or 2. Default: 1.
|
|
2919
2928
|
margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
|
|
2920
|
-
reduction (str): Apply specific reduction method to the output: 'none', 'mean',
|
|
2929
|
+
reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
|
|
2930
|
+
``'sum'`` . Default: ``'mean'`` .
|
|
2931
|
+
|
|
2932
|
+
- ``'none'``: no reduction will be applied.
|
|
2933
|
+
- ``'mean'``: compute and return the weighted mean of elements in the output.
|
|
2934
|
+
- ``'sum'``: the output elements will be summed.
|
|
2921
2935
|
|
|
2922
2936
|
Inputs:
|
|
2923
2937
|
- **y_grad** (Tensor) - If it's not a scalar, the shape of 'y_grad' :math:`(N, C)`.
|
|
@@ -3818,3 +3832,34 @@ class WKVGrad(Primitive):
|
|
|
3818
3832
|
"""Initialize WKVGrad."""
|
|
3819
3833
|
self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
|
|
3820
3834
|
outputs=["gw", "gu", "gk", "gv"])
|
|
3835
|
+
|
|
3836
|
+
|
|
3837
|
+
class FlashAttentionScoreGrad(Primitive):
    r"""
    Calculates the gradient of FlashAttentionScore operation.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        head_num (int): Must be an int.
        keep_prob (Union[int, float]): Must be in the range [0.0, 1.0]. Default: ``1.0`` .
        scale_value (float): Must be a float. Default: ``1.0`` .
        pre_tokens (int): Must be an int. Default: ``65536`` .
        next_tokens (int): Must be an int. Default: ``65536`` .
        inner_precise (int): Must be 0 or 1. Default: ``1`` .
        input_layout (str): Only ``'BSH'`` is accepted. Default: ``'BSH'`` .

    Raises:
        ValueError: If `inner_precise` is neither 0 nor 1.
        ValueError: If `input_layout` is not ``'BSH'``.

    Supported Platforms:
        ``Ascend``
    """

    @prim_attr_register
    def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
                 input_layout='BSH'):
        """Initialize FlashAttentionScoreGrad."""
        validator.check_value_type('head_num', head_num, [int], self.name)
        validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
        validator.check_float(keep_prob, 0.0, validator.GE, "keep_prob", self.name)
        validator.check_float(keep_prob, 1.0, validator.LE, "keep_prob", self.name)
        validator.check_value_type('scale_value', scale_value, [float], self.name)
        validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
        validator.check_value_type('next_tokens', next_tokens, [int], self.name)
        validator.check_value_type('inner_precise', inner_precise, [int], self.name)
        if inner_precise not in [0, 1]:
            raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
        validator.check_value_type('input_layout', input_layout, [str], self.name)
        # Only the 'BSH' layout passes this check, so the error message names exactly that value.
        if input_layout not in ["BSH"]:
            raise ValueError(f"Attribute 'input_layout' must be 'BSH', but got {input_layout}")
        self.init_prim_io_names(inputs=['query', 'key', 'value', 'attn_mask', 'attention_in', 'softmax_max',
                                        'softmax_sum', 'dy', 'drop_mask', 'real_shift', "padding_mask", 'softmax_out'],
                                outputs=['dq', 'dk', 'dv'])
|