mindspore-2.5.0-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,8 @@
 # ============================================================================
 
 from mindspore.common._stub_tensor import _convert_stub
-from mindspore.ops.
+from mindspore.ops._utils.arg_handler import *
+from mindspore._c_expression import AdaptiveMaxPool2DPrim_
 from mindspore._c_expression import ArgMaxWithValuePrim_
 from mindspore._c_expression import ArgMinWithValuePrim_
 from mindspore._c_expression import BatchMatMulPrim_
@@ -66,23 +67,34 @@ from mindspore._c_expression import SmoothL1LossPrim_
 from mindspore._c_expression import SoftmaxPrim_
 from mindspore._c_expression import SoftShrinkGradPrim_
 from mindspore._c_expression import SoftShrinkPrim_
+from mindspore._c_expression import SoftMarginLossGradPrim_
+from mindspore._c_expression import SoftMarginLossPrim_
 from mindspore._c_expression import SplitPrim_
 from mindspore._c_expression import SqueezePrim_
 from mindspore._c_expression import StackExtPrim_
-from mindspore._c_expression import TrilExtPrim_
 from mindspore._c_expression import TriuPrim_
 from mindspore._c_expression import UniqueConsecutivePrim_
 from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
 from mindspore._c_expression import UpsampleTrilinear3DPrim_
+from mindspore._c_expression import FusedInferAttentionScorePrim_
 from mindspore._c_expression import GroupedMatmulPrim_
 from mindspore._c_expression import QuantBatchMatmulPrim_
 from mindspore._c_expression import WeightQuantBatchMatmulPrim_
 
 
+class _PyboostAdaptiveMaxPool2DPrim(AdaptiveMaxPool2DPrim_):
+    def __call__(self, input, output_size):
+
+        return super().__call__([input, output_size])
+
+
+adaptive_max_pool2d_impl = _PyboostAdaptiveMaxPool2DPrim()
+
+
 class _PyboostArgMaxWithValuePrim(ArgMaxWithValuePrim_):
     def __call__(self, input, axis, keep_dims):
 
-        return
+        return super().__call__([input, axis, keep_dims])
 
 
 argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
@@ -91,7 +103,7 @@ argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
 class _PyboostArgMinWithValuePrim(ArgMinWithValuePrim_):
     def __call__(self, input, axis, keep_dims):
 
-        return
+        return super().__call__([input, axis, keep_dims])
 
 
 argmin_with_value_impl = _PyboostArgMinWithValuePrim()
@@ -100,7 +112,7 @@ argmin_with_value_impl = _PyboostArgMinWithValuePrim()
 class _PyboostBatchMatMulPrim(BatchMatMulPrim_):
     def __call__(self, x, y, transpose_a, transpose_b):
 
-        return
+        return super().__call__([x, y, transpose_a, transpose_b])
 
 
 batch_mat_mul_impl = _PyboostBatchMatMulPrim()
@@ -109,7 +121,7 @@ batch_mat_mul_impl = _PyboostBatchMatMulPrim()
 class _PyboostBatchNormGradExtPrim(BatchNormGradExtPrim_):
     def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask):
 
-        return
+        return super().__call__([dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask])
 
 
 batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
@@ -118,7 +130,7 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
 class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
     def __call__(self, input, target, grad_output, weight, reduction):
         converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
-        return
+        return super().__call__([input, target, grad_output, weight, converted_reduction])
 
 
 binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
@@ -127,7 +139,7 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
 class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
     def __call__(self, input, target, weight, reduction):
         converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
-        return
+        return super().__call__([input, target, weight, converted_reduction])
 
 
 binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
@@ -136,7 +148,7 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
 class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
     def __call__(self, input, target, weight, posWeight, reduction):
         converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
-        return
+        return super().__call__([input, target, weight, posWeight, converted_reduction])
 
 
 binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
@@ -145,7 +157,7 @@ binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
 class _PyboostBroadcastToPrim(BroadcastToPrim_):
     def __call__(self, input, shape):
 
-        return
+        return super().__call__([input, shape])
 
 
 broadcast_to_impl = _PyboostBroadcastToPrim()
@@ -154,7 +166,7 @@ broadcast_to_impl = _PyboostBroadcastToPrim()
 class _PyboostConcatPrim(ConcatPrim_):
     def __call__(self, tensors, axis):
 
-        return
+        return super().__call__([tensors, axis])
 
 
 concat_impl = _PyboostConcatPrim()
@@ -163,7 +175,7 @@ concat_impl = _PyboostConcatPrim()
 class _PyboostCrossPrim(CrossPrim_):
     def __call__(self, input, other, dim):
 
-        return
+        return super().__call__([input, other, dim])
 
 
 cross_impl = _PyboostCrossPrim()
@@ -172,7 +184,7 @@ cross_impl = _PyboostCrossPrim()
 class _PyboostCummaxPrim(CummaxPrim_):
     def __call__(self, input, axis):
 
-        return
+        return super().__call__([input, axis])
 
 
 cummax_impl = _PyboostCummaxPrim()
@@ -181,7 +193,7 @@ cummax_impl = _PyboostCummaxPrim()
 class _PyboostEluExtPrim(EluExtPrim_):
     def __call__(self, input, alpha):
 
-        return
+        return super().__call__([input, alpha])
 
 
 elu_ext_impl = _PyboostEluExtPrim()
@@ -190,7 +202,7 @@ elu_ext_impl = _PyboostEluExtPrim()
 class _PyboostFFNExtPrim(FFNExtPrim_):
     def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
         converted_activation = str_to_enum('ffn_ext', 'activation', activation)
-        return
+        return super().__call__([x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise])
 
 
 ffn_ext_impl = _PyboostFFNExtPrim()
@@ -199,7 +211,7 @@ ffn_ext_impl = _PyboostFFNExtPrim()
 class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
     def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
         converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
-        return
+        return super().__call__([query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
 
 
 flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
@@ -208,7 +220,7 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
 class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
     def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
         converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
-        return
+        return super().__call__([query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
 
 
 flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
@@ -217,7 +229,7 @@ flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
 class _PyboostGluGradPrim(GluGradPrim_):
     def __call__(self, grads, x, axis):
 
-        return
+        return super().__call__([grads, x, axis])
 
 
 glu_grad_impl = _PyboostGluGradPrim()
@@ -226,7 +238,7 @@ glu_grad_impl = _PyboostGluGradPrim()
 class _PyboostGLUPrim(GLUPrim_):
     def __call__(self, x, axis):
 
-        return
+        return super().__call__([x, axis])
 
 
 glu_impl = _PyboostGLUPrim()
@@ -236,7 +248,7 @@ class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
         converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
         converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
-        return
+        return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
 
 
 grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
@@ -246,7 +258,7 @@ class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
         converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
         converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
-        return
+        return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
 
 
 grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
@@ -256,7 +268,7 @@ class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
     def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
         converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
         converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
-        return
+        return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
 
 
 grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
@@ -266,7 +278,7 @@ class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
     def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
         converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
         converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
-        return
+        return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
 
 
 grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
@@ -275,7 +287,7 @@ grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
 class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
     def __call__(self, gradients, features, lambd):
 
-        return
+        return super().__call__([gradients, features, lambd])
 
 
 hshrink_grad_impl = _PyboostHShrinkGradPrim()
@@ -284,7 +296,7 @@ hshrink_grad_impl = _PyboostHShrinkGradPrim()
 class _PyboostHShrinkPrim(HShrinkPrim_):
     def __call__(self, input, lambd):
 
-        return
+        return super().__call__([input, lambd])
 
 
 hshrink_impl = _PyboostHShrinkPrim()
@@ -293,7 +305,7 @@ hshrink_impl = _PyboostHShrinkPrim()
 class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
     def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
         converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
-        return
+        return super().__call__([query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise])
 
 
 incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
@@ -302,7 +314,7 @@ incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
 class _PyboostIsClosePrim(IsClosePrim_):
     def __call__(self, input, other, rtol, atol, equal_nan):
 
-        return
+        return super().__call__([input, other, rtol, atol, equal_nan])
 
 
 isclose_impl = _PyboostIsClosePrim()
@@ -311,7 +323,7 @@ isclose_impl = _PyboostIsClosePrim()
 class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
     def __call__(self, logits, grad, axis):
 
-        return
+        return super().__call__([logits, grad, axis])
 
 
 log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
@@ -320,7 +332,7 @@ log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
 class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
     def __call__(self, logits, axis):
 
-        return
+        return super().__call__([logits, axis])
 
 
 log_softmax_impl = _PyboostLogSoftmaxPrim()
@@ -329,7 +341,7 @@ log_softmax_impl = _PyboostLogSoftmaxPrim()
 class _PyboostMatMulPrim(MatMulPrim_):
     def __call__(self, input, mat2, transpose_a, transpose_b):
 
-        return
+        return super().__call__([input, mat2, transpose_a, transpose_b])
 
 
 matmul_impl = _PyboostMatMulPrim()
@@ -341,7 +353,7 @@ class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
         converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
         converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
         converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
-        return
+        return super().__call__([x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
 
 
 max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
@@ -353,7 +365,7 @@ class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
         converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
         converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
         converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
-        return
+        return super().__call__([x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
 
 
 max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
@@ -365,7 +377,7 @@ class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
         converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
         converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
         converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
-        return
+        return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
 
 
 max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
@@ -377,7 +389,7 @@ class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
         converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
         converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
         converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
-        return
+        return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
 
 
 max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
@@ -386,7 +398,7 @@ max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
 class _PyboostMeshgridPrim(MeshgridPrim_):
     def __call__(self, inputs, indexing):
         converted_indexing = str_to_enum('meshgrid', 'indexing', indexing)
-        return
+        return super().__call__([inputs, converted_indexing])
 
 
 meshgrid_impl = _PyboostMeshgridPrim()
@@ -395,7 +407,7 @@ meshgrid_impl = _PyboostMeshgridPrim()
 class _PyboostNanToNumPrim(NanToNumPrim_):
     def __call__(self, input, nan, posinf, neginf):
 
-        return
+        return super().__call__([input, nan, posinf, neginf])
 
 
 nan_to_num_impl = _PyboostNanToNumPrim()
@@ -404,7 +416,7 @@ nan_to_num_impl = _PyboostNanToNumPrim()
 class _PyboostNLLLossGradPrim(NLLLossGradPrim_):
     def __call__(self, logits, loss_grad, labels, weight, total_weight, reduction, ignore_index):
         converted_reduction = str_to_enum('nllloss_grad', 'reduction', reduction)
-        return
+        return super().__call__([logits, loss_grad, labels, weight, total_weight, converted_reduction, ignore_index])
 
 
 nllloss_grad_impl = _PyboostNLLLossGradPrim()
@@ -413,7 +425,7 @@ nllloss_grad_impl = _PyboostNLLLossGradPrim()
 class _PyboostNLLLossPrim(NLLLossPrim_):
     def __call__(self, logits, labels, weight, reduction, ignore_index):
         converted_reduction = str_to_enum('nllloss', 'reduction', reduction)
-        return
+        return super().__call__([logits, labels, weight, converted_reduction, ignore_index])
 
 
 nllloss_impl = _PyboostNLLLossPrim()
@@ -422,7 +434,7 @@ nllloss_impl = _PyboostNLLLossPrim()
 class _PyboostOneHotExtPrim(OneHotExtPrim_):
     def __call__(self, tensor, num_classes, on_value, off_value, axis):
 
-        return
+        return super().__call__([tensor, num_classes, on_value, off_value, axis])
 
 
 one_hot_ext_impl = _PyboostOneHotExtPrim()
@@ -431,7 +443,7 @@ one_hot_ext_impl = _PyboostOneHotExtPrim()
 class _PyboostPromptFlashAttentionPrim(PromptFlashAttentionPrim_):
     def __call__(self, query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise):
         converted_input_layout = str_to_enum('prompt_flash_attention', 'input_layout', input_layout)
-        return
+        return super().__call__([query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise])
 
 
 prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
@@ -440,7 +452,7 @@ prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
 class _PyboostReduceAllPrim(ReduceAllPrim_):
     def __call__(self, input, axis, keep_dims):
 
-        return
+        return super().__call__([input, axis, keep_dims])
 
 
 reduce_all_impl = _PyboostReduceAllPrim()
@@ -449,7 +461,7 @@ reduce_all_impl = _PyboostReduceAllPrim()
 class _PyboostReduceAnyPrim(ReduceAnyPrim_):
     def __call__(self, x, axis, keep_dims):
 
-        return
+        return super().__call__([x, axis, keep_dims])
 
 
 reduce_any_impl = _PyboostReduceAnyPrim()
@@ -458,7 +470,7 @@ reduce_any_impl = _PyboostReduceAnyPrim()
 class _PyboostReduceMaxPrim(ReduceMaxPrim_):
     def __call__(self, x, axis, keep_dims):
 
-        return
+        return super().__call__([x, axis, keep_dims])
 
 
 reduce_max_impl = _PyboostReduceMaxPrim()
@@ -467,7 +479,7 @@ reduce_max_impl = _PyboostReduceMaxPrim()
 class _PyboostReduceMinPrim(ReduceMinPrim_):
     def __call__(self, x, axis, keep_dims):
 
-        return
+        return super().__call__([x, axis, keep_dims])
 
 
 reduce_min_impl = _PyboostReduceMinPrim()
@@ -476,7 +488,7 @@ reduce_min_impl = _PyboostReduceMinPrim()
 class _PyboostReverseV2Prim(ReverseV2Prim_):
     def __call__(self, input, axis):
 
-        return
+        return super().__call__([input, axis])
 
 
 reverse_v2_impl = _PyboostReverseV2Prim()
@@ -485,16 +497,16 @@ reverse_v2_impl = _PyboostReverseV2Prim()
 class _PyboostRmsNormPrim(RmsNormPrim_):
     def __call__(self, x, gamma, epsilon):
 
-        return
+        return super().__call__([x, gamma, epsilon])
 
 
 rms_norm_impl = _PyboostRmsNormPrim()
 
 
 class _PyboostRollPrim(RollPrim_):
-    def __call__(self, input,
+    def __call__(self, input, shifts, dims):
 
-        return
+        return super().__call__([input, shifts, dims])
 
 
 roll_impl = _PyboostRollPrim()
@@ -503,7 +515,7 @@ roll_impl = _PyboostRollPrim()
 class _PyboostSearchSortedPrim(SearchSortedPrim_):
     def __call__(self, sorted_sequence, values, sorter, dtype, right):
 
-        return
+        return super().__call__([sorted_sequence, values, sorter, dtype, right])
 
 
 searchsorted_impl = _PyboostSearchSortedPrim()
@@ -512,7 +524,7 @@ searchsorted_impl = _PyboostSearchSortedPrim()
 class _PyboostSmoothL1LossGradPrim(SmoothL1LossGradPrim_):
     def __call__(self, prediction, target, dout, beta, reduction):
         converted_reduction = str_to_enum('smooth_l1_loss_grad', 'reduction', reduction)
-        return
+        return super().__call__([prediction, target, dout, beta, converted_reduction])
 
 
 smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
@@ -521,7 +533,7 @@ smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
 class _PyboostSmoothL1LossPrim(SmoothL1LossPrim_):
     def __call__(self, prediction, target, beta, reduction):
         converted_reduction = str_to_enum('smooth_l1_loss', 'reduction', reduction)
-        return
+        return super().__call__([prediction, target, beta, converted_reduction])
 
 
 smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
@@ -530,7 +542,7 @@ smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
 class _PyboostSoftmaxPrim(SoftmaxPrim_):
     def __call__(self, input, axis):
 
-        return
+        return super().__call__([input, axis])
 
 
 softmax_impl = _PyboostSoftmaxPrim()
@@ -539,7 +551,7 @@ softmax_impl = _PyboostSoftmaxPrim()
 class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
     def __call__(self, input_grad, input_x, lambd):
 
-        return
+        return super().__call__([input_grad, input_x, lambd])
 
 
 softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
@@ -548,16 +560,34 @@ softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
 class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
     def __call__(self, input, lambd):
 
-        return
+        return super().__call__([input, lambd])
 
 
 softshrink_impl = _PyboostSoftShrinkPrim()
 
 
+class _PyboostSoftMarginLossGradPrim(SoftMarginLossGradPrim_):
+    def __call__(self, predict, label, dout, reduction):
+        converted_reduction = str_to_enum('soft_margin_loss_grad', 'reduction', reduction)
+        return super().__call__([predict, label, dout, converted_reduction])
+
+
+soft_margin_loss_grad_impl = _PyboostSoftMarginLossGradPrim()
+
+
+class _PyboostSoftMarginLossPrim(SoftMarginLossPrim_):
+    def __call__(self, input, target, reduction):
+        converted_reduction = str_to_enum('soft_margin_loss', 'reduction', reduction)
+        return super().__call__([input, target, converted_reduction])
+
+
+soft_margin_loss_impl = _PyboostSoftMarginLossPrim()
+
+
 class _PyboostSplitPrim(SplitPrim_):
     def __call__(self, input_x, axis, output_num):
 
-        return
+        return super().__call__([input_x, axis, output_num])
 
 
 split_impl = _PyboostSplitPrim()
@@ -566,7 +596,7 @@ split_impl = _PyboostSplitPrim()
 class _PyboostSqueezePrim(SqueezePrim_):
     def __call__(self, input, axis):
 
-        return
+        return super().__call__([input, axis])
 
 
 squeeze_impl = _PyboostSqueezePrim()
@@ -575,34 +605,25 @@ squeeze_impl = _PyboostSqueezePrim()
 class _PyboostStackExtPrim(StackExtPrim_):
     def __call__(self, tensors, dim):
 
-        return
+        return super().__call__([tensors, dim])
 
 
 stack_ext_impl = _PyboostStackExtPrim()
 
 
-class _PyboostTrilExtPrim(TrilExtPrim_):
-    def __call__(self, input, diagonal):
-
-        return _convert_stub(super().__call__([input, diagonal]))
-
-
-tril_ext_impl = _PyboostTrilExtPrim()
-
-
 class _PyboostTriuPrim(TriuPrim_):
     def __call__(self, input, diagonal):
 
-        return
+        return super().__call__([input, diagonal])
 
 
 triu_impl = _PyboostTriuPrim()
 
 
 class _PyboostUniqueConsecutivePrim(UniqueConsecutivePrim_):
-    def __call__(self, input,
+    def __call__(self, input, return_inverse, return_counts, dim):
 
-        return
+        return super().__call__([input, return_inverse, return_counts, dim])
 
 
 unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
@@ -611,7 +632,7 @@ unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
 class _PyboostUpsampleTrilinear3DGradPrim(UpsampleTrilinear3DGradPrim_):
     def __call__(self, dy, input_size, output_size, scales, align_corners):
 
-        return
+        return super().__call__([dy, input_size, output_size, scales, align_corners])
 
 
 upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
@@ -620,16 +641,25 @@ upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
 class _PyboostUpsampleTrilinear3DPrim(UpsampleTrilinear3DPrim_):
     def __call__(self, x, output_size, scales, align_corners):
 
-        return
+        return super().__call__([x, output_size, scales, align_corners])
 
 
 upsample_trilinear3d_impl = _PyboostUpsampleTrilinear3DPrim()
 
 
+class _PyboostFusedInferAttentionScorePrim(FusedInferAttentionScorePrim_):
+    def __call__(self, query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode):
+        converted_input_layout = str_to_enum('fused_infer_attention_score', 'input_layout', input_layout)
+        return super().__call__([query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode])
+
+
+fused_infer_attention_score_impl = _PyboostFusedInferAttentionScorePrim()
+
+
 class _PyboostGroupedMatmulPrim(GroupedMatmulPrim_):
-    def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type):
+    def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b):
 
-        return
+        return super().__call__([x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b])
 
 
 grouped_matmul_impl = _PyboostGroupedMatmulPrim()
@@ -638,7 +668,7 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()
 class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
     def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):
 
-        return
+        return super().__call__([x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype])
 
 
 quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
@@ -647,7 +677,7 @@ quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
 class _PyboostWeightQuantBatchMatmulPrim(WeightQuantBatchMatmulPrim_):
     def __call__(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size):
 
-        return
+        return super().__call__([x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size])
 
 
 weight_quant_batch_matmul_impl = _PyboostWeightQuantBatchMatmulPrim()
@@ -25,7 +25,7 @@ from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, Mu
 from mindspore.ops.composite.env_ops import env_get
 from mindspore.ops.function.clip_func import clip_by_global_norm
 from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
-from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
+from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
 from mindspore.ops.composite.math_ops import matmul, cummin, mm
@@ -46,6 +46,7 @@ __all__ = [
     'hyper_add',
     'zeros_like',
     'ones_like',
+    '_ones_like_for_grad',
     'zip_operation',
     'normal',
     'laplace',