mindspore-2.1.0-cp37-cp37m-manylinux1_x86_64.whl → mindspore-2.2.11-cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/dense.py
CHANGED
@@ -77,6 +77,7 @@ class BiDense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
             The values of str refer to the function `initializer`. Default: ``None`` .
         has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Shape:
         - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
           are the same shape as the inputs.
 
     Dtype:
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2
-        - **
+        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
+        - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
         - **output** (Tensor) - With the same dtype as the inputs.
 
     Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
                  out_channels,
                  weight_init=None,
                  bias_init=None,
-                 has_bias=True
+                 has_bias=True,
+                 dtype=mstype.float32):
         super().__init__()
         self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
         self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
                              f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
                              f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
                              f"'in2_channels': {in2_channels}")
-        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)
+        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
+                                'weight')
 
         if self.has_bias:
             if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
                 raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
         self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul()
 
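The hunks above thread a new `dtype` argument through `BiDense` into its weight and bias initializers. A minimal usage sketch of the new argument follows; the tensor shapes and the float16 choice are illustrative assumptions, not taken from the diff:

import numpy as np
import mindspore as ms
from mindspore import nn

# BiDense now accepts dtype and builds its Parameters in that precision,
# per the signature change shown above; float16 here is only an example.
layer = nn.BiDense(in1_channels=20, in2_channels=30, out_channels=40, dtype=ms.float16)

x1 = ms.Tensor(np.random.randn(8, 20).astype(np.float16))
x2 = ms.Tensor(np.random.randn(8, 30).astype(np.float16))
out = layer(x1, x2)
print(out.shape, out.dtype)  # expected: (8, 40) Float16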
mindspore/nn/layer/embedding.py
CHANGED
@@ -64,11 +64,13 @@ class Embedding(Cell):
         embedding_size (int): The size of each embedding vector.
         use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
         embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
-            Refer to class `initializer
-
+            Refer to class `mindspore.common.initializer
+            <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
+            for the values of string when a string is specified. Default: ``'normal'`` .
         dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
         padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
             will be initialized to zero. Default: ``None`` . The feature is inactivated.
+
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
         return output
 
     def extend_repr(self):
-
-        self.
-        return s
+        return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+               f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
 
 
 @_primexpr
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
             parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
             optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):
 
     def __init__(self, vocab_size, embedding_size, param_init='normal',
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
-                 max_norm=None, sparse=True, vocab_cache_size=0):
+                 max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
         """Initialize EmbeddingLookup."""
         super(EmbeddingLookup, self).__init__()
         Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
         if enable_ps:
             self._process_vocab_cache(slice_mode)
         self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
-        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]
-
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
+                                                     dtype=dtype), name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
         if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
             if is_auto_parallel:
                 self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
         else:
             if is_auto_parallel:
                 support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
-                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                                 "but got \"{}\"."
+                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
+                                 f"but got \"{slice_mode}\".")
             if self.cache_enable and not enable_ps:
                 raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
             if _is_role_worker():
                 self.vocab_size = self.vocab_cache_size
 
-    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init
+    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
+                                       dtype=mstype.float32):
         """PS embeddingLookup cache enable set."""
         if self.sparse:
             self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
         if _enable_distributed_mindrt():
             self.rank_id = get_rank()
             if self.is_ps_server:
-                self._slice_pserver_embeddings("zeros")
+                self._slice_pserver_embeddings("zeros", dtype=dtype)
                 self._set_cache_enable_and_key_for_pserver(param_key)
 
-    def _slice_pserver_embeddings(self, param_init):
+    def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
         '''
         Method to slice embedding tables on Parameter Servers.
         It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
         for i in range(server_num):
             self.embedding_table_list.append(Parameter(initializer(param_init,
                                                                    [self.embedding_table_vocab_dim_list[i],
-                                                                    self.embedding_size]),
+                                                                    self.embedding_size], dtype=dtype),
                                                        name="embedding_table_server_" + str(i)))
 
             self.embedding_offset.append(offset)
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
         feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
             Default: ``None`` .
-        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
-
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
+            Default: ``None`` .
         sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
             Default: ``True`` .
         operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
             ``'MAX'`` . Default: ``'SUM'`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
         TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
         ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
-        ValueError: If `target` is neither 'CPU' nor 'DEVICE'
-        ValueError: If `slice_mode` is not one of 'batch_slice'
-            'table_column_slice'.
-        ValueError: If `sparse` is False and `target` is 'CPU'.
-        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
-        ValueError: If `operator` is not one of 'SUM'
+        ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
+        ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
+            ``'table_column_slice'`` .
+        ValueError: If `sparse` is False and `target` is ``'CPU'`` .
+        ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
+        ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
     OPERATOR_MAX = 'MAX'
 
     def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
-                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'
+                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
+                 dtype=mstype.float32):
         """Initialize MultiFieldEmbeddingLookup."""
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
-                                                        slice_mode, feature_num_list, max_norm, sparse)
+                                                        slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
         self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
         self.operator = operator
 
@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
         else:
             if is_auto_parallel:
-                raise ValueError(
-
+                raise ValueError(
+                    f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
+                    f"and 'table_column_slice'], but got {str(slice_mode)}.")
 
         # Min value for fp32
         self.negative_inf_value = -3.402823466E+38
mindspore/nn/layer/flash_attention.py
CHANGED
@@ -17,12 +17,11 @@ A FlashAttention Layer.
 """
 import math
 
-import mindspore.
-from mindspore import ops
-from mindspore.common import dtype as mstype
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore import ops
 from mindspore.nn.cell import Cell
-from mindspore.ops.
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
 
 __all__ = ['FlashAttention']
 
@@ -45,25 +44,25 @@ class FlashAttention(Cell):
             Default 65536.
         next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
             Default 65536.
-        tiling_stgy_name(str): A str to define tiling strategy of flash attention.
         dp(int): data parallel.
             Default 1.
         mp(int): model parallel.
            Default 1.
-        high_precision(bool): This mode has higher precision but some performance loss.
+        high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
            Default False.
        have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
            Default True
        alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
            Default: False
+        use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.
 
 
     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
-
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.
 
     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -92,35 +91,55 @@ class FlashAttention(Cell):
 
     def __init__(self,
                  head_dim,
+                 head_num,
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
-                 tiling_stgy_name="sparse",
                  dp=1,
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()
 
-        self.flash_attention = get_flash_attention(
-            prev_block_num=prev_block_num,
-            next_block_num=next_block_num,
-            tiling_stgy_name=tiling_stgy_name,
-            high_precision=high_precision
-        )
-        self.flash_attention.add_prim_attr("primitive_target", "Ascend")
         scaling_constant = math.sqrt(head_dim)
-        if scaling_constant
-            self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-        else:
+        if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.have_attention_mask_batch = have_attention_mask_batch
         self.alibi = alibi
+        self.have_attention_mask_batch = have_attention_mask_batch
+
+        self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+        self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+        self.reshape = ops.Reshape()
+        self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+        self.zeros = ops.Zeros()
+        self.attn_cast = ops.Cast()
+        if use_mqa:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, 1, 1, 1),
+                             (dp, 1, 1, 1))
+        else:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+        if self.alibi:
+            self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
+            fa_strategies += ((dp, mp, 1, 1),)
+        if dropout_rate > 1e-5:
+            fa_strategies += ((dp, mp, 1, 1),)
+        fa_strategies += ((dp, 1, 1, 1),)
+        self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                   next_tokens=next_block_num,
+                                                   keep_prob=1 - dropout_rate,
+                                                   scale_value=1. / scaling_constant,
+                                                   inner_precise=0,
+                                                   input_layout="BNSD").shard(fa_strategies)
+
+        self.dropout_rate = dropout_rate
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -136,49 +155,7 @@ class FlashAttention(Cell):
             such as MatMul. Default: None.
         :return:
         """
-
-            shard_stgy = list(in_strategy)
-            shard_stgy.insert(3, (1,))  # dim_mask
-            shard_stgy = tuple(shard_stgy)
-        else:
-            # default: dp=1, mp=1, construct inputs only contain query, key, value
-            shard_stgy = (
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1,),  # dim_mask
-            )
-        self.flash_attention.shard(shard_stgy)
-        dp = shard_stgy[0][0]
-        mp = shard_stgy[0][1]
-        self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
-        inputs_tensor_map = [
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [-1]
-        ]
-        if self.have_attention_mask_batch:
-            inputs_tensor_map.append([3, 1, 0])
-        else:
-            inputs_tensor_map.append([-1, 1, 0])
-
-        # dropout_mask
-        if self.dropout_rate > 1e-5:
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        if self.alibi:
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
-
-        self.flash_attention.add_prim_attr("outputs_tensor_map", [
-            [3, 2, 1, 0],  # O
-            [3, 2, 1],  # L
-            [3, 2, 1]  # M
-        ])
-        self.flash_attention.add_prim_attr("as_loss_divisor", 0)
-        self.flash_attention.add_prim_attr("empty_mirror_ops", 1)
+        self.flash_attention.shard(in_strategy)
 
     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -189,35 +166,24 @@ class FlashAttention(Cell):
         :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
         :return: output [bsz, head_num, seq_len, head_dim]
         """
-
-
-        _, k_head_num, k_seq_len, _ = key.shape
-        _, v_head_num, v_seq_len, _ = value.shape
-        if head_num != k_head_num or head_num != v_head_num:
-            raise ValueError(
-                "the head_num of query, key and value must be the same, "
-                "If different head_num are used, users need to change themselves to be same by tile.")
-        if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
-            raise ValueError(
-                "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
+        bsz, head_num, seq_len, _ = query.shape
+        # 910B -- FlashAttentionScore
         if self.dropout_rate > 1e-5:
-            drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
-
-            ones = self.fill_v2(tensor_shape, self.tensor_one)
-            ones = self.depend(ones, query)
-            drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
-        else:
-            drop_mask = None
-        if head_dim > 304:
-            raise ValueError(
-                "the head_dim must be less than 304, otherwise the ub would be OOM.")
-        if head_dim % 16 != 0:
-            padding_size = 16 - head_dim % 16
-            query = mnp.pad(query, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-            key = mnp.pad(key, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-            value = mnp.pad(value, ((0, 0), (0, 0), (0, 0), (0, padding_size)), constant_values=0)
-        output, _, _ = self.flash_attention(query, key, value, self.dim_mask, attn_mask, drop_mask, alibi_mask)
-        output = ops.slice(output, [0, 0, 0, 0], [bsz, head_num, seq_len, head_dim])
+            drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+                                          (bsz, head_num, seq_len, seq_len // 8))
         else:
-
+            drop_mask_bits = None
+        if self.alibi:
+            alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
+        # (B, S, S) -> (B, 1, S, S)
+        if self.have_attention_mask_batch:
+            attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+        _, _, _, output = self.flash_attention(query,
+                                               key,
+                                               value,
+                                               alibi_mask,
+                                               drop_mask_bits,
+                                               None,
+                                               attn_mask,
+                                               None)
         return output
mindspore/nn/layer/image.py CHANGED

@@ -83,17 +83,17 @@ class ImageGradients(Cell):
         _check_input_4d(F.shape(images), "images", self.cls_name)
         batch_size, depth, height, width = P.Shape()(images)
         if height == 1:
-            dy =
+            dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
         else:
             dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
-            dy_last =
+            dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
             dy = P.Concat(2)((dy, dy_last))

         if width == 1:
-            dx =
+            dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
         else:
             dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
-            dx_last =
+            dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
             dx = P.Concat(3)((dx, dx_last))
         return dy, dx

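In ImageGradients the zero rows and columns used as padding are now built with the functional fill, F.fill(dtype, shape, value). A minimal sketch of the same padding pattern through the public ops.fill alias, with a made-up input tensor:

# Sketch of the F.fill padding pattern above; the input tensor is made up.
import numpy as np
from mindspore import Tensor, ops

images = Tensor(np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4))
batch_size, depth, height, width = images.shape

# ops.fill(dtype, shape, value) returns a constant tensor of the requested shape.
dy_last = ops.fill(images.dtype, (batch_size, depth, 1, width), 0)
dy = ops.concat((images[:, :, 1:, :] - images[:, :, :height - 1, :], dy_last), axis=2)
assert dy.shape == (2, 3, 4, 4)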
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .

     Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
-    :math:`(*, C, H \times r, W \times r)`,
+    :math:`(*, C, H \times r, W \times r)`,
+    where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.

     Note:
         The dimension of input Tensor on Ascend should be less than 7.

@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .

     Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
-    :math:`(*, C \times r^2, H, W)` ,
+    :math:`(*, C \times r^2, H, W)` ,
+    where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.

     Args:
         downscale_factor (int): factor to unshuffle the input, and is a positive integer.
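The clarified shape relation can be checked directly with the nn.PixelShuffle and nn.PixelUnshuffle cells; the sizes below are arbitrary and chosen only to illustrate :math:`(*, C \times r^2, H, W)` mapping to :math:`(*, C, H \times r, W \times r)` and back.

# Shape check for the documented PixelShuffle/PixelUnshuffle relation (arbitrary sizes).
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

r = 3
x = Tensor(np.zeros((1, 9, 4, 4)), ms.float32)        # C * r**2 = 1 * 3**2
y = nn.PixelShuffle(upscale_factor=r)(x)              # -> (1, 1, 12, 12)
x_back = nn.PixelUnshuffle(downscale_factor=r)(y)     # -> (1, 9, 4, 4)
assert y.shape == (1, 1, 12, 12) and x_back.shape == (1, 9, 4, 4)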
mindspore/nn/layer/math.py CHANGED

@@ -223,7 +223,6 @@ class LGamma(Cell):
         self.abs = P.Abs()
         self.shape = P.Shape()
         self.dtype = P.DType()
-        self.fill = P.Fill()
         self.floor = P.Floor()
         self.equal = P.Equal()
         self.greater = P.Greater()

@@ -240,7 +239,7 @@ class LGamma(Cell):
         if F.is_sequence_value_unknown(self.shape(x)):
             infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
         else:
-            infinity =
+            infinity = F.fill(input_dtype, self.shape(x), self.inf)

         need_to_reflect = self.less(x, 0.5)
         neg_input = -x

@@ -335,7 +334,6 @@ class DiGamma(Cell):
         self.abs = P.Abs()
         self.shape = P.Shape()
         self.dtype = P.DType()
-        self.fill = P.Fill()
         self.floor = P.Floor()
         self.equal = P.Equal()
         self.less = P.Less()

@@ -371,15 +369,12 @@ class DiGamma(Cell):
         reduced_input = x + self.abs(self.floor(x + 0.5))
         reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
         real_result = self.select(need_to_reflect, reflection, y)
-        nan =
+        nan = F.fill(self.dtype(x), self.shape(x), np.nan)

         return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                            nan, real_result)


-eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
 def _while_helper_func(cond, body, vals):
     while cond(vals).any():
         vals = body(vals)

@@ -391,13 +386,12 @@ def _igamma_series(ax, x, a, enabled):

     logicaland = P.LogicalAnd()
     greater = P.Greater()
-    fill = P.Fill()
     shape = P.Shape()
     dtype = P.DType()
     select = P.Select()

     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

     def cond(vals):
         enabled = vals[0]

@@ -424,8 +418,8 @@ def _igamma_series(ax, x, a, enabled):
                 select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
                 select(enabled, dans_da, vals[6]))

-    ones = fill(dtype(a), shape(a), 1)
-    zeros = fill(dtype(a), shape(a), 0)
+    ones = F.fill(dtype(a), shape(a), 1)
+    zeros = F.fill(dtype(a), shape(a), 0)
     vals = (enabled, a, ones, ones, x, zeros, zeros)

     vals = _while_helper_func(cond, body, vals)

@@ -441,13 +435,12 @@ def _igammac_continued_fraction(ax, x, a, enabled):
     greater = P.Greater()
     less = P.Less()
     notequal = P.NotEqual()
-    fill = P.Fill()
     shape = P.Shape()
     dtype = P.DType()
     select = P.Select()

     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

     def cond(vals):
         enabled = vals[0]

@@ -482,7 +475,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
         qk_is_nonzero = notequal(qk, 0)
         r = pk / qk

-        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
+        t = select(qk_is_nonzero, abs_x((ans - r) / r), F.fill(dtype(t), shape(t), 1))
         ans = select(qk_is_nonzero, r, ans)

         dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c

@@ -490,7 +483,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
         dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
         grad_conditional = select(qk_is_nonzero,
                                   abs_x(dans_da_new - dans_da),
-                                  fill(dtype(dans_da), shape(dans_da), 1))
+                                  F.fill(dtype(dans_da), shape(dans_da), 1))

         pkm2 = pkm1
         pkm1 = pk

@@ -525,16 +518,16 @@ def _igammac_continued_fraction(ax, x, a, enabled):

     y = 1 - a
     z = x + y + 1
-    c = fill(dtype(x), shape(x), 0)
-    pkm2 = fill(dtype(x), shape(x), 1)
+    c = F.fill(dtype(x), shape(x), 0)
+    pkm2 = F.fill(dtype(x), shape(x), 1)
     qkm2 = x
     pkm1 = x + 1
     qkm1 = z * x
     ans = pkm1 / qkm1
-    t = fill(dtype(x), shape(x), 1)
-    dpkm2_da = fill(dtype(x), shape(x), 0)
-    dqkm2_da = fill(dtype(x), shape(x), 0)
-    dpkm1_da = fill(dtype(x), shape(x), 0)
+    t = F.fill(dtype(x), shape(x), 1)
+    dpkm2_da = F.fill(dtype(x), shape(x), 0)
+    dqkm2_da = F.fill(dtype(x), shape(x), 0)
+    dpkm1_da = F.fill(dtype(x), shape(x), 0)
     dqkm1_da = -x
     dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
     vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)

@@ -606,7 +599,6 @@ class IGamma(Cell):
         self.exp = P.Exp()
         self.select = P.Select()
         self.zeroslike = P.ZerosLike()
-        self.fill = P.Fill()
         self.shape = P.Shape()
         self.dtype = P.DType()
         self.lgamma = LGamma()

@@ -625,15 +617,14 @@ class IGamma(Cell):
         x = F.broadcast_to(x, para_shape)
         a = F.broadcast_to(a, para_shape)
         x_is_zero = self.equal(x, 0)
-
-        underflow = self.less(ax, self.neg(log_maxfloat))
+        underflow = self.less(ax, self.neg(self.log_maxfloat32))
         ax = self.exp(ax)
         enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
         output = self.select(use_igammac,
                              1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                              _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
         output = self.select(x_is_zero, self.zeroslike(output), output)
-        output = self.select(domain_error,
+        output = self.select(domain_error, F.fill(self.dtype(a), self.shape(a), np.nan), output)
         return output

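The pattern repeated through math.py is the same everywhere: the P.Fill() operator instances are dropped from __init__ and every call site switches to the functional F.fill(dtype, shape, value). A hedged before/after sketch with a toy cell (the cell name is hypothetical, not code from math.py itself):

# Toy before/after for the P.Fill -> F.fill migration; NanLike is a made-up example cell.
import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class NanLike(nn.Cell):
    def __init__(self):
        super().__init__()
        self.shape = P.Shape()
        self.dtype = P.DType()
        # Previously a primitive instance was kept: self.fill = P.Fill(),
        # called as self.fill(dtype, shape, value).

    def construct(self, x):
        # Now the functional interface is called directly at the use site.
        return F.fill(self.dtype(x), self.shape(x), np.nan)

out = NanLike()(Tensor(np.ones((2, 2), np.float32)))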