mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""compile custom kernel with ninja"""
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
import shlex
|
|
20
|
+
import subprocess
|
|
21
|
+
import sysconfig
|
|
22
|
+
import time
|
|
23
|
+
import stat
|
|
24
|
+
from mindspore import log as logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class VersionManager:
    """Track per-module build versions so extensions are rebuilt only when their inputs change.

    For every module name the manager keeps a ``(version, hash)`` pair in memory. The hash
    combines the build directory, the contents of every source file and every build argument.
    Whenever a freshly computed hash differs from the stored one, the version is incremented.
    The table lives only as long as this object; nothing is persisted to disk.
    """

    # Python ints are unbounded, so the ``seed << 6`` term of hash_combine would otherwise make
    # the combined hash grow by ~6 bits on every step. Keep it to 64 bits like the C++ original.
    _HASH_MASK = (1 << 64) - 1

    def __init__(self):
        self.entries = {}  # module_name : (version, hash)

    def _get_version(self, module_name):
        """Return the last recorded version of ``module_name``, or None if it was never recorded."""
        return self.entries.get(module_name, (None, None))[0]

    def _update_version_if_changed(self, module_name, sources, build_args, build_dir):
        """Recompute the input hash of ``module_name`` and bump its version if the hash changed.

        Args:
            module_name (str): Extension module name, used as the key of the version table.
            sources (list[str]): Paths of the source files; their contents are hashed.
            build_args (list): Groups of build arguments (e.g. cflags, ldflags, include paths).
                Falsy groups (None, empty list) are skipped.
            build_dir (str): Build directory; it is part of the hash.

        Returns:
            int: 0 the first time a module is seen, incremented each time its inputs change.
        """
        hash_value = self._update_hash(0, build_dir)
        hash_value = self._update_sources_hash(hash_value, sources)
        hash_value = self._update_args_hash(hash_value, build_args)

        entry = self.entries.get(module_name)
        if entry is None:
            entry = (0, hash_value)
            self.entries[module_name] = entry
        elif hash_value != entry[1]:
            entry = (entry[0] + 1, hash_value)
            self.entries[module_name] = entry

        return entry[0]

    def _update_hash(self, seed, value):
        """Fold ``hash(value)`` into ``seed`` (boost::hash_combine) and return a 64-bit result."""
        combined = seed ^ (hash(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2))
        return combined & self._HASH_MASK

    def _update_sources_hash(self, hash_value, sources):
        """Fold the contents of every file in ``sources`` into ``hash_value``.

        Files are read as raw bytes so hashing never depends on the platform's default text
        encoding (e.g. cp1252 on Windows cannot decode arbitrary UTF-8 source files).
        """
        for filename in sources:
            with open(filename, 'rb') as file:
                hash_value = self._update_hash(hash_value, file.read())
        return hash_value

    def _update_args_hash(self, hash_value, build_args):
        """Fold every argument of every non-empty group in ``build_args`` into ``hash_value``."""
        for group in build_args:
            if group:
                for argument in group:
                    hash_value = self._update_hash(hash_value, argument)
        return hash_value

    def check_version(self, name, sources, cflags, ldflags, include_paths, build_dir):
        """Decide whether extension module ``name`` has to be (re)built.

        Args:
            name (str): Extension module name.
            sources (list[str]): Source file paths.
            cflags (list[str]): Extra compiler flags.
            ldflags (list[str]): Extra linker flags.
            include_paths (list[str]): Extra include directories.
            build_dir (str): Build directory of the module.

        Returns:
            bool: True if the module was never seen before or any input changed since the
            previous call, False otherwise.
        """
        old_version = self._get_version(name)
        version = self._update_version_if_changed(name, sources, [cflags, ldflags, include_paths], build_dir)
        logger.info(f'Build module {name}, version={version}')

        if version != old_version:
            if version > 0:
                # NOTE(review): only a bool is returned, so the versioned name appears in logs only;
                # the caller in this file still builds under the original ``name``.
                logger.info(
                    f'The conditions for extension module {name} have changed. '
                    f'Updating to version {version} and re-building as {name}_v{version}.'
                )
            return True

        if version > 0:
            name = f'{name}_v{version}'
        logger.info(f'No modifications detected for extension module {name}')
        return False
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Module-level singleton consulted by ExtensionBuilder._compile via check_version().
# Its version table is an in-memory dict created empty at import time, so each process starts fresh.
version_manager = VersionManager()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class FileLocker:
    """Cross-process lock backed by the file ``<build_dir>/build.lock``.

    The lock is held while the file exists. It is created atomically with
    ``O_CREAT | O_EXCL``, so at most one caller can create it at a time.
    """

    def __init__(self, build_dir):
        """Prepare (but do not acquire) the lock for ``build_dir``.

        Args:
            build_dir (str): Directory in which ``build.lock`` is created; it must exist
                before ``try_lock`` is called.
        """
        self.lock_file_name = os.path.join(build_dir, 'build.lock')
        self.lock_fd = None

    def try_lock(self):
        """Try to acquire the lock without blocking.

        Returns:
            bool: True if this instance now holds the lock, False if the lock file already exists.
        """
        mode = stat.S_IRUSR | stat.S_IWUSR
        try:
            self.lock_fd = os.open(self.lock_file_name, os.O_CREAT | os.O_EXCL, mode)
        except FileExistsError:
            return False
        return True

    def release_lock(self):
        """Release the lock if this instance holds it; otherwise do nothing."""
        if self.lock_fd is None:
            return
        try:
            os.close(self.lock_fd)
        finally:
            # Reset and remove the file even if close() failed: deleting the file is what
            # actually frees the lock for waiting processes.
            self.lock_fd = None
            os.remove(self.lock_file_name)

    def wait(self, timeout=None, poll_interval=0.5):
        """Block until the lock file disappears.

        Args:
            timeout (float, optional): Maximum number of seconds to wait. ``None`` (default)
                waits indefinitely. A finite value avoids hanging forever on a stale lock left
                by a process that died while holding it.
            poll_interval (float): Seconds between existence checks. Default: 0.5.

        Returns:
            bool: True once the lock file is gone, False if ``timeout`` expired first.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while os.path.exists(self.lock_file_name):
            if deadline is None:
                time.sleep(poll_interval)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            time.sleep(min(poll_interval, remaining))
        return True
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class ExtensionBuilder:
|
|
124
|
+
"""ExtensionBuilder"""
|
|
125
|
+
|
|
126
|
+
def __init__(self):
    """Create an ExtensionBuilder.

    The constructor stores no attributes on the instance; it exists only to provide a
    documented entry point.
    """
|
|
128
|
+
|
|
129
|
+
def _get_build_directory(self, module_name):
    """Return the build directory for ``module_name``, creating it if necessary.

    The root directory comes from the ``MS_COMPILER_CACHE_PATH`` environment variable.
    When that variable is unset or empty, ``./kernel_meta`` resolved against the current
    working directory is used instead.

    Args:
        module_name (str): Extension module name; used as the sub-directory name.

    Returns:
        str: ``<root>/<module_name>``. The directory exists when this returns.
    """
    build_root = os.environ.get('MS_COMPILER_CACHE_PATH')
    if not build_root:
        # An empty value would otherwise silently build into ``./<module_name>``.
        build_root = os.path.realpath("./kernel_meta")
        logger.info(f'Using {build_root} as MindSpore extensions root...')

    build_dir = os.path.join(build_root, module_name)
    # exist_ok=True already tolerates an existing directory (including one created
    # concurrently by another process), so no separate exists() check is needed.
    os.makedirs(build_dir, exist_ok=True)
    return build_dir
|
|
140
|
+
|
|
141
|
+
def _compile(self, name, sources, cflags, ldflags, include_paths, build_dir):
    """Build extension ``name`` when its inputs changed since the last build in this process.

    Only the caller that acquires the file lock in ``build_dir`` runs the build. Any other
    caller blocks until that lock file is removed and then skips building.
    """
    must_build = version_manager.check_version(name, sources, cflags, ldflags, include_paths, build_dir)
    if must_build:
        build_lock = FileLocker(build_dir)
        if not build_lock.try_lock():
            # Another builder holds the lock: wait for it instead of building again.
            # NOTE(review): the waiter cannot tell whether that other build succeeded.
            build_lock.wait()
        else:
            try:
                self._write_ninja_file_and_build_library(name, sources, cflags, ldflags, include_paths, build_dir)
            finally:
                # Drop the lock even when the build raises.
                build_lock.release_lock()
    logger.info(f'Loading extension module {name}...')
|
|
153
|
+
|
|
154
|
+
def _verify_ninja_availability(self):
|
|
155
|
+
"""Check ninja is available."""
|
|
156
|
+
try:
|
|
157
|
+
subprocess.check_output('ninja --version'.split())
|
|
158
|
+
except Exception:
|
|
159
|
+
raise RuntimeError("Ninja is required to load C++ extensions")
|
|
160
|
+
|
|
161
|
+
def _write_ninja_file_and_build_library(self, module_name, sources, cflags, ldflags, include_paths, build_dir):
|
|
162
|
+
"""Write ninja file and build library."""
|
|
163
|
+
self._verify_ninja_availability()
|
|
164
|
+
|
|
165
|
+
ninja_build_file = os.path.join(build_dir, 'build.ninja')
|
|
166
|
+
logger.info(f'Save ninja build file {ninja_build_file}.')
|
|
167
|
+
self._write_ninja_file(ninja_build_file, module_name, sources, cflags, ldflags, include_paths)
|
|
168
|
+
|
|
169
|
+
logger.info(f'Building extension module {module_name}.')
|
|
170
|
+
self._run_ninja_build(build_dir, module_name)
|
|
171
|
+
|
|
172
|
+
def _write_ninja_file(self, fname, name, sources, extra_cflags, extra_ldflags, extra_include_paths):
|
|
173
|
+
"""Write ninja file."""
|
|
174
|
+
python_include_path = sysconfig.get_path('include', scheme='posix_prefix')
|
|
175
|
+
python_includes = [python_include_path] if python_include_path is not None else []
|
|
176
|
+
cflags = [f'-DMS_EXTENSION_NAME={name}', "-D_GLIBCXX_USE_CXX11_ABI=0"]
|
|
177
|
+
cflags += [f'-I{shlex.quote(os.path.abspath(include.strip()))}' for include in extra_include_paths]
|
|
178
|
+
cflags += [f'-isystem {shlex.quote(include)}' for include in python_includes]
|
|
179
|
+
cflags += ['-fPIC', '-std=c++17']
|
|
180
|
+
cflags += extra_cflags
|
|
181
|
+
cflags = [flag.strip() for flag in cflags]
|
|
182
|
+
|
|
183
|
+
# '/path/to/file.cpp' -> 'file'
|
|
184
|
+
objs = [os.path.splitext(os.path.basename(src))[0] + ".o" for src in sources]
|
|
185
|
+
sources = [os.path.abspath(file) for file in sources]
|
|
186
|
+
ldflags = ['-shared'] + [flag.strip() for flag in extra_ldflags]
|
|
187
|
+
target = name + '.so'
|
|
188
|
+
|
|
189
|
+
config = ['ninja_required_version = 1.3']
|
|
190
|
+
config.append('cxx = ' + os.environ.get('CXX', 'g++'))
|
|
191
|
+
|
|
192
|
+
flags = [f'cflags = {" ".join(cflags)}']
|
|
193
|
+
flags.append(f'ldflags = {" ".join(ldflags)}')
|
|
194
|
+
|
|
195
|
+
compile_rule = ['rule compile']
|
|
196
|
+
compile_rule.append(' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out')
|
|
197
|
+
compile_rule.append(' depfile = $out.d')
|
|
198
|
+
compile_rule.append(' deps = gcc')
|
|
199
|
+
|
|
200
|
+
build = [f'build {obj.replace(" ", "$ ")}: compile {src.replace(" ", "$ ")}' for src, obj in zip(sources, objs)]
|
|
201
|
+
|
|
202
|
+
link_rule = ['rule link', ' command = $cxx $in $ldflags -o $out']
|
|
203
|
+
link = [f'build {target}: link {" ".join(objs)}']
|
|
204
|
+
default = [f'default {target}']
|
|
205
|
+
|
|
206
|
+
blocks = [config, flags, compile_rule, link_rule, build, link, default]
|
|
207
|
+
content = "\n\n".join("\n".join(b) for b in blocks) + "\n"
|
|
208
|
+
|
|
209
|
+
if os.path.exists(fname):
|
|
210
|
+
with open(fname) as f:
|
|
211
|
+
old_content = f.read()
|
|
212
|
+
if old_content == content:
|
|
213
|
+
return
|
|
214
|
+
|
|
215
|
+
with open(fname, 'w') as source_file:
|
|
216
|
+
source_file.write(content)
|
|
217
|
+
|
|
218
|
+
def _run_ninja_build(self, build_dir, module_name):
|
|
219
|
+
"""Run ninja build."""
|
|
220
|
+
cmd = ['ninja', '-v']
|
|
221
|
+
env = os.environ.copy()
|
|
222
|
+
|
|
223
|
+
try:
|
|
224
|
+
subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=build_dir, check=True, env=env)
|
|
225
|
+
# If the build succeeds, do nothing with the output (silent)
|
|
226
|
+
except subprocess.CalledProcessError as e:
|
|
227
|
+
# Capture the error details
|
|
228
|
+
stderr_output = e.stderr.decode() if e.stderr else ""
|
|
229
|
+
stdout_output = e.stdout.decode() if e.stdout else ""
|
|
230
|
+
full_output = stderr_output + stdout_output
|
|
231
|
+
|
|
232
|
+
# Format the error message
|
|
233
|
+
msg = f"Error building extension '{module_name}': {full_output}"
|
|
234
|
+
|
|
235
|
+
# In multi-card situation, only one process build the library.
|
|
236
|
+
# When building failed, the old extension library should be removed.
|
|
237
|
+
so_file = os.path.join(build_dir, f"{module_name}.so")
|
|
238
|
+
if os.path.exists(so_file):
|
|
239
|
+
os.remove(so_file)
|
|
240
|
+
raise RuntimeError(msg) from e
|
|
241
|
+
|
|
242
|
+
def build(self, module_name, sources, extra_cflags=None, extra_ldflags=None, extra_include_paths=None):
|
|
243
|
+
"""Build module."""
|
|
244
|
+
src = [sources] if isinstance(sources, str) else sources
|
|
245
|
+
build_dir = self._get_build_directory(module_name)
|
|
246
|
+
self._compile(module_name, src, extra_cflags, extra_ldflags, extra_include_paths, build_dir)
|
|
247
|
+
return os.path.join(build_dir, f"{module_name}.so")
|
|
@@ -36,7 +36,7 @@ from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad,
|
|
|
36
36
|
ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
|
|
37
37
|
HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
|
|
38
38
|
FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad, MaskedSelectGrad,
|
|
39
|
-
BinaryCrossEntropyGrad, SoftShrinkGrad, SeluGrad, SmoothL1LossGrad)
|
|
39
|
+
BinaryCrossEntropyGrad, SoftShrinkGrad, SoftMarginLossGrad, SeluGrad, SmoothL1LossGrad)
|
|
40
40
|
|
|
41
41
|
|
|
42
42
|
class SparseFillEmptyRowsGrad(Primitive):
|
|
@@ -1639,15 +1639,6 @@ class SliceGrad(PrimitiveWithInfer):
|
|
|
1639
1639
|
'value': None}
|
|
1640
1640
|
|
|
1641
1641
|
|
|
1642
|
-
class SoftMarginLossGrad(Primitive):
|
|
1643
|
-
"""Computes gradient for prediction on SoftMarginLoss."""
|
|
1644
|
-
|
|
1645
|
-
@prim_attr_register
|
|
1646
|
-
def __init__(self, reduction="mean"):
|
|
1647
|
-
self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
|
|
1648
|
-
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
|
1649
|
-
|
|
1650
|
-
|
|
1651
1642
|
class StridedSliceGrad(Primitive):
|
|
1652
1643
|
"""
|
|
1653
1644
|
Performs grad of StridedSlice operation.
|
|
@@ -21,7 +21,6 @@ import weakref
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
23
23
|
from mindspore.common import Tensor
|
|
24
|
-
from mindspore.common._stub_tensor import StubTensor
|
|
25
24
|
from mindspore.ops import composite as C
|
|
26
25
|
from mindspore.ops.operations.array_ops import Cast
|
|
27
26
|
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
|
|
@@ -29,7 +28,7 @@ from mindspore.ops import signature as sig
|
|
|
29
28
|
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
|
30
29
|
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
|
|
31
30
|
_run_op, _check_contains_variable
|
|
32
|
-
from mindspore._c_expression import
|
|
31
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
33
32
|
from mindspore._c_expression import typing, HookType
|
|
34
33
|
from mindspore._c_expression import pyboost_generator
|
|
35
34
|
from mindspore import _checkparam as validator
|
|
@@ -38,8 +37,6 @@ from mindspore.common.parameter import Parameter
|
|
|
38
37
|
from mindspore.common._stub_tensor import _convert_stub
|
|
39
38
|
from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
|
|
40
39
|
from mindspore.common.api import _pynative_executor
|
|
41
|
-
from mindspore.common._register_for_adapter import ms_adapter_registry
|
|
42
|
-
from mindspore import ops
|
|
43
40
|
from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
|
|
44
41
|
PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4
|
|
45
42
|
|
|
@@ -668,17 +665,17 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
|
|
|
668
665
|
|
|
669
666
|
class SequenceMask(PrimitiveWithCheck):
|
|
670
667
|
"""
|
|
671
|
-
Returns a mask tensor representing the first N positions of each cell.
|
|
668
|
+
Returns a mask tensor representing the first N positions of each cell. The internal element data type is bool.
|
|
672
669
|
|
|
673
670
|
If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type and shape
|
|
674
671
|
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
|
|
675
672
|
|
|
676
673
|
Inputs:
|
|
677
|
-
- **lengths** (Tensor) -
|
|
674
|
+
- **lengths** (Tensor) - The input tensor. All values in this tensor should be
|
|
678
675
|
less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
|
|
679
676
|
Must be type int32 or int64.
|
|
680
677
|
|
|
681
|
-
- **maxlen** (int) -
|
|
678
|
+
- **maxlen** (int) - Specifies the length of the returned tensor. Must be positive and same
|
|
682
679
|
type as elements in `lengths`.
|
|
683
680
|
|
|
684
681
|
Outputs:
|
|
@@ -2256,74 +2253,6 @@ class IsInstance(PrimitiveWithInfer):
|
|
|
2256
2253
|
return out
|
|
2257
2254
|
|
|
2258
2255
|
|
|
2259
|
-
class ConvertToAdapterTensor(Primitive):
|
|
2260
|
-
"""
|
|
2261
|
-
Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
|
|
2262
|
-
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
|
|
2263
|
-
|
|
2264
|
-
Inputs:
|
|
2265
|
-
- **x** (Tensor) - The input tensor.
|
|
2266
|
-
|
|
2267
|
-
Outputs:
|
|
2268
|
-
A tensor, whose type is MSAdapter's Tensor.
|
|
2269
|
-
|
|
2270
|
-
Supported Platforms:
|
|
2271
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2272
|
-
|
|
2273
|
-
Examples:
|
|
2274
|
-
>>> x = Tensor([1, 2 ,3])
|
|
2275
|
-
>>> x = ops.ConvertToAdapterTensor()(x)
|
|
2276
|
-
>>> print(x)
|
|
2277
|
-
[1 2 3]
|
|
2278
|
-
"""
|
|
2279
|
-
|
|
2280
|
-
@prim_attr_register
|
|
2281
|
-
def __init__(self):
|
|
2282
|
-
"""Initialize"""
|
|
2283
|
-
|
|
2284
|
-
def __call__(self, x):
|
|
2285
|
-
"""Run in PyNative mode"""
|
|
2286
|
-
return ms_adapter_registry.tensor(x, cast_tensor=True)
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
convert_to_adapter_tensor = ConvertToAdapterTensor()
|
|
2290
|
-
|
|
2291
|
-
|
|
2292
|
-
class ConvertToMsTensor(Primitive):
|
|
2293
|
-
"""
|
|
2294
|
-
Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
|
|
2295
|
-
where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
|
|
2296
|
-
|
|
2297
|
-
Inputs:
|
|
2298
|
-
- **x** (Tensor) - The input tensor.
|
|
2299
|
-
|
|
2300
|
-
Outputs:
|
|
2301
|
-
A tensor, whose type is MindSpore's Tensor.
|
|
2302
|
-
|
|
2303
|
-
Supported Platforms:
|
|
2304
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2305
|
-
|
|
2306
|
-
Examples:
|
|
2307
|
-
>>> x = Tensor([1, 2 ,3])
|
|
2308
|
-
>>> x = ops.ConvertToMsTensor()(x)
|
|
2309
|
-
>>> print(x)
|
|
2310
|
-
[1 2 3]
|
|
2311
|
-
"""
|
|
2312
|
-
|
|
2313
|
-
@prim_attr_register
|
|
2314
|
-
def __init__(self):
|
|
2315
|
-
"""Initialize"""
|
|
2316
|
-
|
|
2317
|
-
def __call__(self, x):
|
|
2318
|
-
"""Run in PyNative mode"""
|
|
2319
|
-
if isinstance(x, StubTensor):
|
|
2320
|
-
return StubTensor(stub=x.stub, tensor=x.tensor)
|
|
2321
|
-
return ops.auto_generate.deepcopy(x)
|
|
2322
|
-
|
|
2323
|
-
|
|
2324
|
-
convert_to_ms_tensor = ConvertToMsTensor()
|
|
2325
|
-
|
|
2326
|
-
|
|
2327
2256
|
class GetGrad(Primitive):
|
|
2328
2257
|
"""
|
|
2329
2258
|
Use the position id or Parameter object to get the gradient from the output
|
|
@@ -2475,7 +2404,7 @@ class FFN(Primitive):
|
|
|
2475
2404
|
The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
|
|
2476
2405
|
|
|
2477
2406
|
Args:
|
|
2478
|
-
activation (
|
|
2407
|
+
activation (str): The activation type, set to 'fastgelu' or 'gelu'.
|
|
2479
2408
|
Only support 'fastgelu' for now. Default: "fastgelu".
|
|
2480
2409
|
inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
|
|
2481
2410
|
Only support 1 for now. Default: 0.
|
|
@@ -496,7 +496,7 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
|
|
|
496
496
|
callable function is equal to the case when `fn` is not None.
|
|
497
497
|
|
|
498
498
|
Supported Platforms:
|
|
499
|
-
``
|
|
499
|
+
``GPU`` ``CPU``
|
|
500
500
|
|
|
501
501
|
Examples:
|
|
502
502
|
>>> import numpy as np
|
|
@@ -510,12 +510,12 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
|
|
|
510
510
|
... "test3": 12,
|
|
511
511
|
... }
|
|
512
512
|
>>> # Create the reg info json string.
|
|
513
|
-
>>>
|
|
513
|
+
>>> op_cpu_info = CustomRegOp() \\
|
|
514
514
|
... .input(0, "a") \\
|
|
515
515
|
... .input(0, "b") \\
|
|
516
516
|
... .output(0, "y") \\
|
|
517
517
|
... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \\
|
|
518
|
-
... .target("
|
|
518
|
+
... .target("CPU") \\
|
|
519
519
|
... .get_op_info()
|
|
520
520
|
>>>
|
|
521
521
|
>>> # Create inputs for the custom op.
|
|
@@ -524,7 +524,7 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
|
|
|
524
524
|
...
|
|
525
525
|
>>> # Write a Hybrid DSL function through the decorator @kernel.
|
|
526
526
|
>>> # We can also pass the compile attrs and the reg info through the decorator.
|
|
527
|
-
>>> @kernel(reg_info=
|
|
527
|
+
>>> @kernel(reg_info=op_cpu_info, compile_attrs=attrs)
|
|
528
528
|
... def outer_product(a, b):
|
|
529
529
|
... c = output_tensor(a.shape, a.dtype)
|
|
530
530
|
...
|
|
@@ -539,12 +539,6 @@ def kernel(fn=None, reg_info=None, compile_attrs=None):
|
|
|
539
539
|
>>> # We can use the function directly as a python function.
|
|
540
540
|
>>> # In this case, the inputs should be numpy arrays.
|
|
541
541
|
>>> result = outer_product(input_x, input_y)
|
|
542
|
-
...
|
|
543
|
-
>>> # Create a custom op with mode "hybrid" (default value) by the Hybrid DSL function.
|
|
544
|
-
>>> # In this case, we will enjoy the automatic dtype/shape infer for free.
|
|
545
|
-
>>> # The inputs should be mindspore tensors.
|
|
546
|
-
>>> test_op_hybrid = ops.Custom(outer_product)
|
|
547
|
-
>>> output = test_op_hybrid(Tensor(input_x), Tensor(input_y))
|
|
548
542
|
"""
|
|
549
543
|
if compile_attrs is None:
|
|
550
544
|
compile_attrs = {}
|
|
@@ -859,7 +859,7 @@ class TensorsQueueCreate(PrimitiveWithInfer):
|
|
|
859
859
|
dtype (mindspore.dtype): the data type in the TensorsQueue.
|
|
860
860
|
shapes (tuple(tuple(int))): the shape of each tensor in element.
|
|
861
861
|
size (int): The size of the TensorsQueue.
|
|
862
|
-
name (
|
|
862
|
+
name (str): the name of this TensorsQueue. Default: "Q".
|
|
863
863
|
|
|
864
864
|
Inputs:
|
|
865
865
|
None.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2023-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
from mindspore.ops.primitive import Primitive, prim_attr_register
|
|
18
18
|
from .manually_defined import ScalarAdd, ScalarBool, ScalarDiv, ScalarMul, ScalarEq, ScalarFloorDiv, ScalarGe, \
|
|
19
|
-
ScalarGt, ScalarLe, ScalarLog, ScalarLt, ScalarMod, ScalarPow, ScalarSub, ScalarUadd, ScalarUsub
|
|
19
|
+
ScalarGt, ScalarLe, ScalarLog, ScalarLt, ScalarMod, ScalarPow, ScalarSub, ScalarUadd, ScalarUsub, ScalarMax, \
|
|
20
|
+
ScalarMin
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class bool_not(Primitive):
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
from mindspore.ops.primitive import Primitive, PrimitiveWithCheck, prim_attr_register
|
|
17
17
|
import mindspore._checkparam as validator
|
|
18
18
|
from mindspore.common import Tensor
|
|
19
|
-
from mindspore._c_expression import
|
|
19
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class ListAppend(Primitive):
|
|
@@ -33,7 +33,7 @@ class TensorArray(PrimitiveWithInfer):
|
|
|
33
33
|
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
|
|
34
34
|
dynamic_size (bool): If true the TensorArray can increase the size. Default: ``True``.
|
|
35
35
|
size (int): The size of the TensorArray if dynamic_size = False.
|
|
36
|
-
name (
|
|
36
|
+
name (str): the name of this TensorArray. Default: "TA".
|
|
37
37
|
|
|
38
38
|
Inputs:
|
|
39
39
|
None.
|
|
@@ -32,7 +32,7 @@ from mindspore._checkparam import _check_3d_int_or_tuple
|
|
|
32
32
|
from mindspore.common import dtype as mstype
|
|
33
33
|
from mindspore.common._decorator import deprecated
|
|
34
34
|
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
35
|
-
from mindspore._c_expression import
|
|
35
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
36
36
|
from mindspore._c_expression import CSRTensor as CSRTensor_
|
|
37
37
|
from mindspore._c_expression import COOTensor as COOTensor_
|
|
38
38
|
from ..auto_generate import (
|
|
@@ -273,7 +273,7 @@ class Im2Col(Primitive):
|
|
|
273
273
|
each sliding `ksizes`- sized block within the spatial dimensions
|
|
274
274
|
of input `x` into a column (i.e., last dimension) of a 4-D output
|
|
275
275
|
tensor of shape :math:`(N, C, \prod(\text{kernel_size}), L)`, where
|
|
276
|
-
:math:`C \times \prod(\text{kernel_size})` is the total number of
|
|
276
|
+
:math:`C \times \prod(\text{kernel_size})` is the total number of elements
|
|
277
277
|
within each block (a block has :math:`\prod(\text{kernel_size})` spatial
|
|
278
278
|
locations each containing a `C`-channeled vector), and :math:`L` is
|
|
279
279
|
the total number of such blocks:
|
|
@@ -1169,7 +1169,7 @@ class TupleToArray(PrimitiveWithInfer):
|
|
|
1169
1169
|
|
|
1170
1170
|
Inputs:
|
|
1171
1171
|
- **input_x** (tuple) - A tuple of numbers. These numbers have the same type.
|
|
1172
|
-
The shape is :math:`(N
|
|
1172
|
+
The shape is :math:`(N,)`.
|
|
1173
1173
|
|
|
1174
1174
|
Outputs:
|
|
1175
1175
|
Tensor, if the input tuple contains `N` numbers, then the shape of the output tensor is :math:`(N,)`.
|
|
@@ -1802,9 +1802,9 @@ class Unstack(Primitive):
|
|
|
1802
1802
|
Refer to :func:`mindspore.ops.unstack` for more details.
|
|
1803
1803
|
|
|
1804
1804
|
Args:
|
|
1805
|
-
axis (int): Dimension along which to unpack. Default: ``0`` .
|
|
1805
|
+
axis (int, optional): Dimension along which to unpack. Default: ``0`` .
|
|
1806
1806
|
Negative values wrap around. The range is [-R, R).
|
|
1807
|
-
num (Union[None, int]): The number of output tensors.
|
|
1807
|
+
num (Union[None, int], optional): The number of output tensors.
|
|
1808
1808
|
Automatically inferred by input_x and axis if ``None`` . Default: ``None`` .
|
|
1809
1809
|
|
|
1810
1810
|
Inputs:
|
|
@@ -2113,7 +2113,7 @@ class ScatterNdUpdate(Primitive):
|
|
|
2113
2113
|
the relatively highest priority data type.
|
|
2114
2114
|
|
|
2115
2115
|
Args:
|
|
2116
|
-
use_locking (bool): Whether to protect the assignment by a lock. Default: ``True`` .
|
|
2116
|
+
use_locking (bool, optional): Whether to protect the assignment by a lock. Default: ``True`` .
|
|
2117
2117
|
|
|
2118
2118
|
Inputs:
|
|
2119
2119
|
- **input_x** (Union[Parameter, Tensor]) - The target tensor, with data type of Parameter or Tensor.
|
|
@@ -2244,7 +2244,7 @@ class ScatterMin(_ScatterOpDynamic):
|
|
|
2244
2244
|
when `updates` does not support conversion to the data type required by `input_x`.
|
|
2245
2245
|
|
|
2246
2246
|
Args:
|
|
2247
|
-
use_locking (bool): Whether to protect the assignment by a lock. Default: ``False`` .
|
|
2247
|
+
use_locking (bool, optional): Whether to protect the assignment by a lock. Default: ``False`` .
|
|
2248
2248
|
|
|
2249
2249
|
Inputs:
|
|
2250
2250
|
- **input_x** (Union[Parameter, Tensor]) - The target tensor, with data type of Parameter or Tensor.
|
|
@@ -2306,7 +2306,7 @@ class ScatterAdd(Primitive):
|
|
|
2306
2306
|
This is an in-place update operator. Therefore, the `input_x` will be updated after the operation is completed.
|
|
2307
2307
|
|
|
2308
2308
|
Args:
|
|
2309
|
-
use_locking (bool): Whether to protect the assignment by a lock.
|
|
2309
|
+
use_locking (bool, optional): Whether to protect the assignment by a lock.
|
|
2310
2310
|
If ``True`` , `input_x` will be protected by the lock.
|
|
2311
2311
|
Otherwise, the calculation result is undefined. Default: ``False`` .
|
|
2312
2312
|
|
|
@@ -2426,7 +2426,7 @@ class ScatterSub(Primitive):
|
|
|
2426
2426
|
the relatively highest priority data type.
|
|
2427
2427
|
|
|
2428
2428
|
Args:
|
|
2429
|
-
use_locking (bool): Whether to protect the assignment by a lock. Default: ``False`` .
|
|
2429
|
+
use_locking (bool, optional): Whether to protect the assignment by a lock. Default: ``False`` .
|
|
2430
2430
|
|
|
2431
2431
|
Inputs:
|
|
2432
2432
|
- **input_x** (Union[Parameter, Tensor]) - The target tensor, with data type of Parameter or Tensor.
|
|
@@ -3037,7 +3037,7 @@ class ScatterNdDiv(_ScatterNdOp):
|
|
|
3037
3037
|
|
|
3038
3038
|
class ScatterNdMax(_ScatterNdOp):
|
|
3039
3039
|
r"""
|
|
3040
|
-
|
|
3040
|
+
Computes sparse maximum to individual values or slices in a tensor.
|
|
3041
3041
|
|
|
3042
3042
|
Using given values to update parameter value through the maximum operation, along with the input indices.
|
|
3043
3043
|
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
|
|
@@ -3587,7 +3587,7 @@ class ReverseSequence(PrimitiveWithInfer):
|
|
|
3587
3587
|
|
|
3588
3588
|
Args:
|
|
3589
3589
|
seq_dim (int): The dimension where reversal is performed. Required.
|
|
3590
|
-
batch_dim (int): The input is sliced in this dimension. Default: ``0`` .
|
|
3590
|
+
batch_dim (int, optional): The input is sliced in this dimension. Default: ``0`` .
|
|
3591
3591
|
|
|
3592
3592
|
Inputs:
|
|
3593
3593
|
- **x** (Tensor) - The input to reverse, supporting all number types including bool.
|
|
@@ -3838,10 +3838,11 @@ class EmbeddingLookup(Primitive):
|
|
|
3838
3838
|
`offset`.
|
|
3839
3839
|
|
|
3840
3840
|
Inputs:
|
|
3841
|
-
- **input_params** (Tensor) -
|
|
3842
|
-
|
|
3843
|
-
- **input_indices** (Tensor) -
|
|
3844
|
-
|
|
3841
|
+
- **input_params** (Tensor) - a Tensor slice, the shape is :math:`(x_1, x_2, ..., x_R)`.
|
|
3842
|
+
Currently, the dimension is restricted to be 2.
|
|
3843
|
+
- **input_indices** (Tensor) - Specifies the indices of elements of the original Tensor.
|
|
3844
|
+
The shape is :math:`(y_1, y_2, ..., y_S)`.
|
|
3845
|
+
Values can be out of range of `input_params`,
|
|
3845
3846
|
and the exceeding part will be filled with 0 in the output. Values do not support negative and the result
|
|
3846
3847
|
is undefined if values are negative. The data type should be int32 or int64.
|
|
3847
3848
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
|
@@ -4068,7 +4069,8 @@ class TensorScatterUpdate(_TensorScatterOp):
|
|
|
4068
4069
|
r"""
|
|
4069
4070
|
Creates a new tensor by updating the positions in `input_x` indicated by
|
|
4070
4071
|
`indices`, with values from `update`. This operation is almost equivalent to using
|
|
4071
|
-
:class:`mindspore.ops.ScatterNdUpdate` , except that the updates are applied on `
|
|
4072
|
+
:class:`mindspore.ops.ScatterNdUpdate` , except that the updates are applied on output `Tensor`
|
|
4073
|
+
instead of `input_x`.
|
|
4072
4074
|
|
|
4073
4075
|
`indices` must have rank at least 2, the last axis is the depth of each index
|
|
4074
4076
|
vectors. For each index vector, there must be a corresponding value in `update`. If
|
|
@@ -4230,12 +4232,25 @@ class TensorScatterSub(Primitive):
|
|
|
4230
4232
|
r"""
|
|
4231
4233
|
Creates a new tensor by subtracting the values from the positions in `input_x` indicated by
|
|
4232
4234
|
`indices`, with values from `updates`. When multiple values are provided for the same
|
|
4233
|
-
index, the result of the update will
|
|
4235
|
+
index, each of those values is subtracted from that position in turn. This operation is almost
|
|
4234
4236
|
equivalent to using :class:`mindspore.ops.ScatterNdSub` , except that the updates are applied on output `Tensor`
|
|
4235
4237
|
instead of input `Parameter`.
|
|
4236
4238
|
|
|
4237
|
-
..
|
|
4238
|
-
|
|
4239
|
+
.. code-block:: python
|
|
4240
|
+
|
|
4241
|
+
# Iterate over every index combination
|
|
4242
|
+
for i in range(indices.shape[0]):
|
|
4243
|
+
for j in range(indices.shape[1]):
|
|
4244
|
+
...
|
|
4245
|
+
for k in range(indices.shape[-2]):  # indices.shape[-1] holds the coordinates, so it is not iterated
|
|
4246
|
+
# Get current index combination
|
|
4247
|
+
index_tuple = (i, j, ..., k)
|
|
4248
|
+
# Get target position
|
|
4249
|
+
target_index = indices[index_tuple]
|
|
4250
|
+
# Get corresponding update value
|
|
4251
|
+
update_value = updates[index_tuple]
|
|
4252
|
+
# Perform subtraction operation
|
|
4253
|
+
output[target_index] -= update_value
|
|
4239
4254
|
|
|
4240
4255
|
Refer to :func:`mindspore.ops.tensor_scatter_sub` for more details.
|
|
4241
4256
|
|
|
@@ -5522,7 +5537,7 @@ class AffineGrid(Primitive):
|
|
|
5522
5537
|
|
|
5523
5538
|
Args:
|
|
5524
5539
|
align_corners (bool, optional): Geometrically, each pixel of input is viewed as a square instead of a dot.
|
|
5525
|
-
If True
|
|
5540
|
+
If ``True``, consider extremum -1 and 1 referring to the centers of the pixels rather than pixel corners.
|
|
5526
5541
|
The default value is ``False`` , extremum -1 and 1 refer to the corners of the pixels, so that sampling is
|
|
5527
5542
|
irrelevant to resolution of the image. Default: ``False`` .
|
|
5528
5543
|
|
|
@@ -5534,7 +5549,7 @@ class AffineGrid(Primitive):
|
|
|
5534
5549
|
or :math:`(N, C, D, H, W)` for 3D grid.
|
|
5535
5550
|
|
|
5536
5551
|
Outputs:
|
|
5537
|
-
Tensor, a tensor whose data type is same as
|
|
5552
|
+
Tensor, a tensor whose data type is same as `theta`, and the shape is :math:`(N, H, W, 2)` for 2D grid
|
|
5538
5553
|
or :math:`(N, D, H, W, 3)` for 3D grid.
|
|
5539
5554
|
|
|
5540
5555
|
Supported Platforms:
|
|
@@ -5743,7 +5758,7 @@ class TopK(Primitive):
|
|
|
5743
5758
|
- CPU: all numeric types.
|
|
5744
5759
|
|
|
5745
5760
|
- **k** (Union(Tensor, int)) - The number of top elements to be computed along the last dimension.
|
|
5746
|
-
|
|
5761
|
+
The supported dtype is int32 and it should be 0-D or 1-D Tensor with shape :math:`(1, )` .
|
|
5747
5762
|
|
|
5748
5763
|
Outputs:
|
|
5749
5764
|
A tuple consisting of `values` and `indexes`.
|