mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -14,7 +14,8 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
16
|
from mindspore.common._stub_tensor import _convert_stub
|
|
17
|
-
from mindspore.ops.
|
|
17
|
+
from mindspore.ops._utils.arg_handler import *
|
|
18
|
+
from mindspore._c_expression import AdaptiveMaxPool2DPrim_
|
|
18
19
|
from mindspore._c_expression import ArgMaxWithValuePrim_
|
|
19
20
|
from mindspore._c_expression import ArgMinWithValuePrim_
|
|
20
21
|
from mindspore._c_expression import BatchMatMulPrim_
|
|
@@ -24,14 +25,14 @@ from mindspore._c_expression import BinaryCrossEntropyPrim_
|
|
|
24
25
|
from mindspore._c_expression import BCEWithLogitsLossPrim_
|
|
25
26
|
from mindspore._c_expression import BroadcastToPrim_
|
|
26
27
|
from mindspore._c_expression import ConcatPrim_
|
|
27
|
-
from mindspore._c_expression import ConvolutionGradPrim_
|
|
28
|
-
from mindspore._c_expression import ConvolutionPrim_
|
|
29
28
|
from mindspore._c_expression import CrossPrim_
|
|
30
29
|
from mindspore._c_expression import CummaxPrim_
|
|
31
30
|
from mindspore._c_expression import EluExtPrim_
|
|
32
31
|
from mindspore._c_expression import FFNExtPrim_
|
|
33
32
|
from mindspore._c_expression import FlashAttentionScoreGradPrim_
|
|
34
33
|
from mindspore._c_expression import FlashAttentionScorePrim_
|
|
34
|
+
from mindspore._c_expression import GluGradPrim_
|
|
35
|
+
from mindspore._c_expression import GLUPrim_
|
|
35
36
|
from mindspore._c_expression import GridSampler2DGradPrim_
|
|
36
37
|
from mindspore._c_expression import GridSampler2DPrim_
|
|
37
38
|
from mindspore._c_expression import GridSampler3DGradPrim_
|
|
@@ -47,31 +48,53 @@ from mindspore._c_expression import MaxPoolGradWithIndicesPrim_
|
|
|
47
48
|
from mindspore._c_expression import MaxPoolGradWithMaskPrim_
|
|
48
49
|
from mindspore._c_expression import MaxPoolWithIndicesPrim_
|
|
49
50
|
from mindspore._c_expression import MaxPoolWithMaskPrim_
|
|
51
|
+
from mindspore._c_expression import MeshgridPrim_
|
|
50
52
|
from mindspore._c_expression import NanToNumPrim_
|
|
53
|
+
from mindspore._c_expression import NLLLossGradPrim_
|
|
54
|
+
from mindspore._c_expression import NLLLossPrim_
|
|
51
55
|
from mindspore._c_expression import OneHotExtPrim_
|
|
56
|
+
from mindspore._c_expression import PromptFlashAttentionPrim_
|
|
52
57
|
from mindspore._c_expression import ReduceAllPrim_
|
|
53
58
|
from mindspore._c_expression import ReduceAnyPrim_
|
|
59
|
+
from mindspore._c_expression import ReduceMaxPrim_
|
|
60
|
+
from mindspore._c_expression import ReduceMinPrim_
|
|
54
61
|
from mindspore._c_expression import ReverseV2Prim_
|
|
55
62
|
from mindspore._c_expression import RmsNormPrim_
|
|
56
63
|
from mindspore._c_expression import RollPrim_
|
|
57
64
|
from mindspore._c_expression import SearchSortedPrim_
|
|
65
|
+
from mindspore._c_expression import SmoothL1LossGradPrim_
|
|
66
|
+
from mindspore._c_expression import SmoothL1LossPrim_
|
|
58
67
|
from mindspore._c_expression import SoftmaxPrim_
|
|
59
68
|
from mindspore._c_expression import SoftShrinkGradPrim_
|
|
60
69
|
from mindspore._c_expression import SoftShrinkPrim_
|
|
70
|
+
from mindspore._c_expression import SoftMarginLossGradPrim_
|
|
71
|
+
from mindspore._c_expression import SoftMarginLossPrim_
|
|
72
|
+
from mindspore._c_expression import SplitPrim_
|
|
73
|
+
from mindspore._c_expression import SqueezePrim_
|
|
61
74
|
from mindspore._c_expression import StackExtPrim_
|
|
62
|
-
from mindspore._c_expression import TrilExtPrim_
|
|
63
75
|
from mindspore._c_expression import TriuPrim_
|
|
76
|
+
from mindspore._c_expression import UniqueConsecutivePrim_
|
|
64
77
|
from mindspore._c_expression import UpsampleTrilinear3DGradPrim_
|
|
65
78
|
from mindspore._c_expression import UpsampleTrilinear3DPrim_
|
|
79
|
+
from mindspore._c_expression import FusedInferAttentionScorePrim_
|
|
66
80
|
from mindspore._c_expression import GroupedMatmulPrim_
|
|
67
81
|
from mindspore._c_expression import QuantBatchMatmulPrim_
|
|
68
82
|
from mindspore._c_expression import WeightQuantBatchMatmulPrim_
|
|
69
83
|
|
|
70
84
|
|
|
85
|
+
class _PyboostAdaptiveMaxPool2DPrim(AdaptiveMaxPool2DPrim_):
|
|
86
|
+
def __call__(self, input, output_size):
|
|
87
|
+
|
|
88
|
+
return super().__call__([input, output_size])
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
adaptive_max_pool2d_impl = _PyboostAdaptiveMaxPool2DPrim()
|
|
92
|
+
|
|
93
|
+
|
|
71
94
|
class _PyboostArgMaxWithValuePrim(ArgMaxWithValuePrim_):
|
|
72
95
|
def __call__(self, input, axis, keep_dims):
|
|
73
96
|
|
|
74
|
-
return
|
|
97
|
+
return super().__call__([input, axis, keep_dims])
|
|
75
98
|
|
|
76
99
|
|
|
77
100
|
argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
|
|
@@ -80,7 +103,7 @@ argmax_with_value_impl = _PyboostArgMaxWithValuePrim()
|
|
|
80
103
|
class _PyboostArgMinWithValuePrim(ArgMinWithValuePrim_):
|
|
81
104
|
def __call__(self, input, axis, keep_dims):
|
|
82
105
|
|
|
83
|
-
return
|
|
106
|
+
return super().__call__([input, axis, keep_dims])
|
|
84
107
|
|
|
85
108
|
|
|
86
109
|
argmin_with_value_impl = _PyboostArgMinWithValuePrim()
|
|
@@ -89,16 +112,16 @@ argmin_with_value_impl = _PyboostArgMinWithValuePrim()
|
|
|
89
112
|
class _PyboostBatchMatMulPrim(BatchMatMulPrim_):
|
|
90
113
|
def __call__(self, x, y, transpose_a, transpose_b):
|
|
91
114
|
|
|
92
|
-
return
|
|
115
|
+
return super().__call__([x, y, transpose_a, transpose_b])
|
|
93
116
|
|
|
94
117
|
|
|
95
118
|
batch_mat_mul_impl = _PyboostBatchMatMulPrim()
|
|
96
119
|
|
|
97
120
|
|
|
98
121
|
class _PyboostBatchNormGradExtPrim(BatchNormGradExtPrim_):
|
|
99
|
-
def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps):
|
|
122
|
+
def __call__(self, dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask):
|
|
100
123
|
|
|
101
|
-
return
|
|
124
|
+
return super().__call__([dout, input, weight, running_mean, running_var, saved_mean, saved_rstd, training, eps, output_mask])
|
|
102
125
|
|
|
103
126
|
|
|
104
127
|
batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
|
|
@@ -107,7 +130,7 @@ batch_norm_grad_ext_impl = _PyboostBatchNormGradExtPrim()
|
|
|
107
130
|
class _PyboostBinaryCrossEntropyGradPrim(BinaryCrossEntropyGradPrim_):
|
|
108
131
|
def __call__(self, input, target, grad_output, weight, reduction):
|
|
109
132
|
converted_reduction = str_to_enum('binary_cross_entropy_grad', 'reduction', reduction)
|
|
110
|
-
return
|
|
133
|
+
return super().__call__([input, target, grad_output, weight, converted_reduction])
|
|
111
134
|
|
|
112
135
|
|
|
113
136
|
binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
|
|
@@ -116,7 +139,7 @@ binary_cross_entropy_grad_impl = _PyboostBinaryCrossEntropyGradPrim()
|
|
|
116
139
|
class _PyboostBinaryCrossEntropyPrim(BinaryCrossEntropyPrim_):
|
|
117
140
|
def __call__(self, input, target, weight, reduction):
|
|
118
141
|
converted_reduction = str_to_enum('binary_cross_entropy', 'reduction', reduction)
|
|
119
|
-
return
|
|
142
|
+
return super().__call__([input, target, weight, converted_reduction])
|
|
120
143
|
|
|
121
144
|
|
|
122
145
|
binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
|
|
@@ -125,7 +148,7 @@ binary_cross_entropy_impl = _PyboostBinaryCrossEntropyPrim()
|
|
|
125
148
|
class _PyboostBCEWithLogitsLossPrim(BCEWithLogitsLossPrim_):
|
|
126
149
|
def __call__(self, input, target, weight, posWeight, reduction):
|
|
127
150
|
converted_reduction = str_to_enum('binary_cross_entropy_with_logits', 'reduction', reduction)
|
|
128
|
-
return
|
|
151
|
+
return super().__call__([input, target, weight, posWeight, converted_reduction])
|
|
129
152
|
|
|
130
153
|
|
|
131
154
|
binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
|
|
@@ -134,7 +157,7 @@ binary_cross_entropy_with_logits_impl = _PyboostBCEWithLogitsLossPrim()
|
|
|
134
157
|
class _PyboostBroadcastToPrim(BroadcastToPrim_):
|
|
135
158
|
def __call__(self, input, shape):
|
|
136
159
|
|
|
137
|
-
return
|
|
160
|
+
return super().__call__([input, shape])
|
|
138
161
|
|
|
139
162
|
|
|
140
163
|
broadcast_to_impl = _PyboostBroadcastToPrim()
|
|
@@ -143,40 +166,16 @@ broadcast_to_impl = _PyboostBroadcastToPrim()
|
|
|
143
166
|
class _PyboostConcatPrim(ConcatPrim_):
|
|
144
167
|
def __call__(self, tensors, axis):
|
|
145
168
|
|
|
146
|
-
return
|
|
169
|
+
return super().__call__([tensors, axis])
|
|
147
170
|
|
|
148
171
|
|
|
149
172
|
concat_impl = _PyboostConcatPrim()
|
|
150
173
|
|
|
151
174
|
|
|
152
|
-
class _PyboostConvolutionGradPrim(ConvolutionGradPrim_):
|
|
153
|
-
def __call__(self, dout, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, output_mask):
|
|
154
|
-
converted_stride = to_strides('convolution_grad', 'stride', stride)
|
|
155
|
-
converted_padding = to_2d_paddings('convolution_grad', 'padding', padding)
|
|
156
|
-
converted_dilation = to_dilations('convolution_grad', 'dilation', dilation)
|
|
157
|
-
converted_output_padding = to_output_padding('convolution_grad', 'output_padding', output_padding)
|
|
158
|
-
return _convert_stub(super().__call__(dout, input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups, output_mask))
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
convolution_grad_impl = _PyboostConvolutionGradPrim()
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
class _PyboostConvolutionPrim(ConvolutionPrim_):
|
|
165
|
-
def __call__(self, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
|
|
166
|
-
converted_stride = to_strides('convolution', 'stride', stride)
|
|
167
|
-
converted_padding = to_2d_paddings('convolution', 'padding', padding)
|
|
168
|
-
converted_dilation = to_dilations('convolution', 'dilation', dilation)
|
|
169
|
-
converted_output_padding = to_output_padding('convolution', 'output_padding', output_padding)
|
|
170
|
-
return _convert_stub(super().__call__(input, weight, bias, converted_stride, converted_padding, converted_dilation, transposed, converted_output_padding, groups))
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
convolution_impl = _PyboostConvolutionPrim()
|
|
174
|
-
|
|
175
|
-
|
|
176
175
|
class _PyboostCrossPrim(CrossPrim_):
|
|
177
176
|
def __call__(self, input, other, dim):
|
|
178
177
|
|
|
179
|
-
return
|
|
178
|
+
return super().__call__([input, other, dim])
|
|
180
179
|
|
|
181
180
|
|
|
182
181
|
cross_impl = _PyboostCrossPrim()
|
|
@@ -185,7 +184,7 @@ cross_impl = _PyboostCrossPrim()
|
|
|
185
184
|
class _PyboostCummaxPrim(CummaxPrim_):
|
|
186
185
|
def __call__(self, input, axis):
|
|
187
186
|
|
|
188
|
-
return
|
|
187
|
+
return super().__call__([input, axis])
|
|
189
188
|
|
|
190
189
|
|
|
191
190
|
cummax_impl = _PyboostCummaxPrim()
|
|
@@ -194,7 +193,7 @@ cummax_impl = _PyboostCummaxPrim()
|
|
|
194
193
|
class _PyboostEluExtPrim(EluExtPrim_):
|
|
195
194
|
def __call__(self, input, alpha):
|
|
196
195
|
|
|
197
|
-
return
|
|
196
|
+
return super().__call__([input, alpha])
|
|
198
197
|
|
|
199
198
|
|
|
200
199
|
elu_ext_impl = _PyboostEluExtPrim()
|
|
@@ -203,7 +202,7 @@ elu_ext_impl = _PyboostEluExtPrim()
|
|
|
203
202
|
class _PyboostFFNExtPrim(FFNExtPrim_):
|
|
204
203
|
def __call__(self, x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, activation, inner_precise):
|
|
205
204
|
converted_activation = str_to_enum('ffn_ext', 'activation', activation)
|
|
206
|
-
return
|
|
205
|
+
return super().__call__([x, weight1, weight2, expertTokens, bias1, bias2, scale, offset, deqScale1, deqScale2, antiquant_scale1, antiquant_scale2, antiquant_offset1, antiquant_offset2, converted_activation, inner_precise])
|
|
207
206
|
|
|
208
207
|
|
|
209
208
|
ffn_ext_impl = _PyboostFFNExtPrim()
|
|
@@ -212,7 +211,7 @@ ffn_ext_impl = _PyboostFFNExtPrim()
|
|
|
212
211
|
class _PyboostFlashAttentionScoreGradPrim(FlashAttentionScoreGradPrim_):
|
|
213
212
|
def __call__(self, query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
|
|
214
213
|
converted_input_layout = str_to_enum('flash_attention_score_grad', 'input_layout', input_layout)
|
|
215
|
-
return
|
|
214
|
+
return super().__call__([query, key, value, dy, pse_shift, drop_mask, padding_mask, atten_mask, softmax_max, softmax_sum, softmax_in, attention_in, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
|
|
216
215
|
|
|
217
216
|
|
|
218
217
|
flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
|
|
@@ -221,17 +220,35 @@ flash_attention_score_grad_impl = _PyboostFlashAttentionScoreGradPrim()
|
|
|
221
220
|
class _PyboostFlashAttentionScorePrim(FlashAttentionScorePrim_):
|
|
222
221
|
def __call__(self, query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, input_layout, sparse_mode):
|
|
223
222
|
converted_input_layout = str_to_enum('flash_attention_score', 'input_layout', input_layout)
|
|
224
|
-
return
|
|
223
|
+
return super().__call__([query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen, actual_seq_kvlen, head_num, keep_prob, scale_value, pre_tokens, next_tokens, inner_precise, converted_input_layout, sparse_mode])
|
|
225
224
|
|
|
226
225
|
|
|
227
226
|
flash_attention_score_impl = _PyboostFlashAttentionScorePrim()
|
|
228
227
|
|
|
229
228
|
|
|
229
|
+
class _PyboostGluGradPrim(GluGradPrim_):
|
|
230
|
+
def __call__(self, grads, x, axis):
|
|
231
|
+
|
|
232
|
+
return super().__call__([grads, x, axis])
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
glu_grad_impl = _PyboostGluGradPrim()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class _PyboostGLUPrim(GLUPrim_):
|
|
239
|
+
def __call__(self, x, axis):
|
|
240
|
+
|
|
241
|
+
return super().__call__([x, axis])
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
glu_impl = _PyboostGLUPrim()
|
|
245
|
+
|
|
246
|
+
|
|
230
247
|
class _PyboostGridSampler2DGradPrim(GridSampler2DGradPrim_):
|
|
231
|
-
def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
|
|
248
|
+
def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
|
|
232
249
|
converted_interpolation_mode = str_to_enum('grid_sampler_2d_grad', 'interpolation_mode', interpolation_mode)
|
|
233
250
|
converted_padding_mode = str_to_enum('grid_sampler_2d_grad', 'padding_mode', padding_mode)
|
|
234
|
-
return
|
|
251
|
+
return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
|
|
235
252
|
|
|
236
253
|
|
|
237
254
|
grid_sampler_2d_grad_impl = _PyboostGridSampler2DGradPrim()
|
|
@@ -241,17 +258,17 @@ class _PyboostGridSampler2DPrim(GridSampler2DPrim_):
|
|
|
241
258
|
def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
|
|
242
259
|
converted_interpolation_mode = str_to_enum('grid_sampler_2d', 'interpolation_mode', interpolation_mode)
|
|
243
260
|
converted_padding_mode = str_to_enum('grid_sampler_2d', 'padding_mode', padding_mode)
|
|
244
|
-
return
|
|
261
|
+
return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
|
|
245
262
|
|
|
246
263
|
|
|
247
264
|
grid_sampler_2d_impl = _PyboostGridSampler2DPrim()
|
|
248
265
|
|
|
249
266
|
|
|
250
267
|
class _PyboostGridSampler3DGradPrim(GridSampler3DGradPrim_):
|
|
251
|
-
def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners):
|
|
268
|
+
def __call__(self, grad, input_x, grid, interpolation_mode, padding_mode, align_corners, output_mask):
|
|
252
269
|
converted_interpolation_mode = str_to_enum('grid_sampler_3d_grad', 'interpolation_mode', interpolation_mode)
|
|
253
270
|
converted_padding_mode = str_to_enum('grid_sampler_3d_grad', 'padding_mode', padding_mode)
|
|
254
|
-
return
|
|
271
|
+
return super().__call__([grad, input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners, output_mask])
|
|
255
272
|
|
|
256
273
|
|
|
257
274
|
grid_sampler_3d_grad_impl = _PyboostGridSampler3DGradPrim()
|
|
@@ -261,7 +278,7 @@ class _PyboostGridSampler3DPrim(GridSampler3DPrim_):
|
|
|
261
278
|
def __call__(self, input_x, grid, interpolation_mode, padding_mode, align_corners):
|
|
262
279
|
converted_interpolation_mode = str_to_enum('grid_sampler_3d', 'interpolation_mode', interpolation_mode)
|
|
263
280
|
converted_padding_mode = str_to_enum('grid_sampler_3d', 'padding_mode', padding_mode)
|
|
264
|
-
return
|
|
281
|
+
return super().__call__([input_x, grid, converted_interpolation_mode, converted_padding_mode, align_corners])
|
|
265
282
|
|
|
266
283
|
|
|
267
284
|
grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
|
|
@@ -270,7 +287,7 @@ grid_sampler_3d_impl = _PyboostGridSampler3DPrim()
|
|
|
270
287
|
class _PyboostHShrinkGradPrim(HShrinkGradPrim_):
|
|
271
288
|
def __call__(self, gradients, features, lambd):
|
|
272
289
|
|
|
273
|
-
return
|
|
290
|
+
return super().__call__([gradients, features, lambd])
|
|
274
291
|
|
|
275
292
|
|
|
276
293
|
hshrink_grad_impl = _PyboostHShrinkGradPrim()
|
|
@@ -279,7 +296,7 @@ hshrink_grad_impl = _PyboostHShrinkGradPrim()
|
|
|
279
296
|
class _PyboostHShrinkPrim(HShrinkPrim_):
|
|
280
297
|
def __call__(self, input, lambd):
|
|
281
298
|
|
|
282
|
-
return
|
|
299
|
+
return super().__call__([input, lambd])
|
|
283
300
|
|
|
284
301
|
|
|
285
302
|
hshrink_impl = _PyboostHShrinkPrim()
|
|
@@ -288,7 +305,7 @@ hshrink_impl = _PyboostHShrinkPrim()
|
|
|
288
305
|
class _PyboostIncreFlashAttentionPrim(IncreFlashAttentionPrim_):
|
|
289
306
|
def __call__(self, query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, input_layout, scale_value, num_key_value_heads, block_size, inner_precise):
|
|
290
307
|
converted_input_layout = str_to_enum('incre_flash_attention', 'input_layout', input_layout)
|
|
291
|
-
return
|
|
308
|
+
return super().__call__([query, key, value, attn_mask, actual_seq_lengths, pse_shift, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, kv_padding_size, num_heads, converted_input_layout, scale_value, num_key_value_heads, block_size, inner_precise])
|
|
292
309
|
|
|
293
310
|
|
|
294
311
|
incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
|
|
@@ -297,7 +314,7 @@ incre_flash_attention_impl = _PyboostIncreFlashAttentionPrim()
|
|
|
297
314
|
class _PyboostIsClosePrim(IsClosePrim_):
|
|
298
315
|
def __call__(self, input, other, rtol, atol, equal_nan):
|
|
299
316
|
|
|
300
|
-
return
|
|
317
|
+
return super().__call__([input, other, rtol, atol, equal_nan])
|
|
301
318
|
|
|
302
319
|
|
|
303
320
|
isclose_impl = _PyboostIsClosePrim()
|
|
@@ -306,7 +323,7 @@ isclose_impl = _PyboostIsClosePrim()
|
|
|
306
323
|
class _PyboostLogSoftmaxGradPrim(LogSoftmaxGradPrim_):
|
|
307
324
|
def __call__(self, logits, grad, axis):
|
|
308
325
|
|
|
309
|
-
return
|
|
326
|
+
return super().__call__([logits, grad, axis])
|
|
310
327
|
|
|
311
328
|
|
|
312
329
|
log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
|
|
@@ -315,7 +332,7 @@ log_softmax_grad_impl = _PyboostLogSoftmaxGradPrim()
|
|
|
315
332
|
class _PyboostLogSoftmaxPrim(LogSoftmaxPrim_):
|
|
316
333
|
def __call__(self, logits, axis):
|
|
317
334
|
|
|
318
|
-
return
|
|
335
|
+
return super().__call__([logits, axis])
|
|
319
336
|
|
|
320
337
|
|
|
321
338
|
log_softmax_impl = _PyboostLogSoftmaxPrim()
|
|
@@ -324,7 +341,7 @@ log_softmax_impl = _PyboostLogSoftmaxPrim()
|
|
|
324
341
|
class _PyboostMatMulPrim(MatMulPrim_):
|
|
325
342
|
def __call__(self, input, mat2, transpose_a, transpose_b):
|
|
326
343
|
|
|
327
|
-
return
|
|
344
|
+
return super().__call__([input, mat2, transpose_a, transpose_b])
|
|
328
345
|
|
|
329
346
|
|
|
330
347
|
matmul_impl = _PyboostMatMulPrim()
|
|
@@ -336,7 +353,7 @@ class _PyboostMaxPoolGradWithIndicesPrim(MaxPoolGradWithIndicesPrim_):
|
|
|
336
353
|
converted_strides = to_strides('max_pool_grad_with_indices', 'strides', strides)
|
|
337
354
|
converted_pads = to_output_padding('max_pool_grad_with_indices', 'pads', pads)
|
|
338
355
|
converted_dilation = to_dilations('max_pool_grad_with_indices', 'dilation', dilation)
|
|
339
|
-
return
|
|
356
|
+
return super().__call__([x, grad, argmax, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
|
|
340
357
|
|
|
341
358
|
|
|
342
359
|
max_pool_grad_with_indices_impl = _PyboostMaxPoolGradWithIndicesPrim()
|
|
@@ -348,7 +365,7 @@ class _PyboostMaxPoolGradWithMaskPrim(MaxPoolGradWithMaskPrim_):
|
|
|
348
365
|
converted_strides = to_strides('max_pool_grad_with_mask', 'strides', strides)
|
|
349
366
|
converted_pads = to_output_padding('max_pool_grad_with_mask', 'pads', pads)
|
|
350
367
|
converted_dilation = to_dilations('max_pool_grad_with_mask', 'dilation', dilation)
|
|
351
|
-
return
|
|
368
|
+
return super().__call__([x, grad, mask, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
|
|
352
369
|
|
|
353
370
|
|
|
354
371
|
max_pool_grad_with_mask_impl = _PyboostMaxPoolGradWithMaskPrim()
|
|
@@ -360,7 +377,7 @@ class _PyboostMaxPoolWithIndicesPrim(MaxPoolWithIndicesPrim_):
|
|
|
360
377
|
converted_strides = to_strides('max_pool_with_indices', 'strides', strides)
|
|
361
378
|
converted_pads = to_output_padding('max_pool_with_indices', 'pads', pads)
|
|
362
379
|
converted_dilation = to_dilations('max_pool_with_indices', 'dilation', dilation)
|
|
363
|
-
return
|
|
380
|
+
return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
|
|
364
381
|
|
|
365
382
|
|
|
366
383
|
max_pool_with_indices_impl = _PyboostMaxPoolWithIndicesPrim()
|
|
@@ -372,34 +389,70 @@ class _PyboostMaxPoolWithMaskPrim(MaxPoolWithMaskPrim_):
|
|
|
372
389
|
converted_strides = to_strides('max_pool_with_mask', 'strides', strides)
|
|
373
390
|
converted_pads = to_output_padding('max_pool_with_mask', 'pads', pads)
|
|
374
391
|
converted_dilation = to_dilations('max_pool_with_mask', 'dilation', dilation)
|
|
375
|
-
return
|
|
392
|
+
return super().__call__([x, converted_kernel_size, converted_strides, converted_pads, converted_dilation, ceil_mode, argmax_type])
|
|
376
393
|
|
|
377
394
|
|
|
378
395
|
max_pool_with_mask_impl = _PyboostMaxPoolWithMaskPrim()
|
|
379
396
|
|
|
380
397
|
|
|
398
|
+
class _PyboostMeshgridPrim(MeshgridPrim_):
|
|
399
|
+
def __call__(self, inputs, indexing):
|
|
400
|
+
converted_indexing = str_to_enum('meshgrid', 'indexing', indexing)
|
|
401
|
+
return super().__call__([inputs, converted_indexing])
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
meshgrid_impl = _PyboostMeshgridPrim()
|
|
405
|
+
|
|
406
|
+
|
|
381
407
|
class _PyboostNanToNumPrim(NanToNumPrim_):
|
|
382
408
|
def __call__(self, input, nan, posinf, neginf):
|
|
383
409
|
|
|
384
|
-
return
|
|
410
|
+
return super().__call__([input, nan, posinf, neginf])
|
|
385
411
|
|
|
386
412
|
|
|
387
413
|
nan_to_num_impl = _PyboostNanToNumPrim()
|
|
388
414
|
|
|
389
415
|
|
|
416
|
+
class _PyboostNLLLossGradPrim(NLLLossGradPrim_):
|
|
417
|
+
def __call__(self, logits, loss_grad, labels, weight, total_weight, reduction, ignore_index):
|
|
418
|
+
converted_reduction = str_to_enum('nllloss_grad', 'reduction', reduction)
|
|
419
|
+
return super().__call__([logits, loss_grad, labels, weight, total_weight, converted_reduction, ignore_index])
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
nllloss_grad_impl = _PyboostNLLLossGradPrim()
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
class _PyboostNLLLossPrim(NLLLossPrim_):
|
|
426
|
+
def __call__(self, logits, labels, weight, reduction, ignore_index):
|
|
427
|
+
converted_reduction = str_to_enum('nllloss', 'reduction', reduction)
|
|
428
|
+
return super().__call__([logits, labels, weight, converted_reduction, ignore_index])
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
nllloss_impl = _PyboostNLLLossPrim()
|
|
432
|
+
|
|
433
|
+
|
|
390
434
|
class _PyboostOneHotExtPrim(OneHotExtPrim_):
|
|
391
435
|
def __call__(self, tensor, num_classes, on_value, off_value, axis):
|
|
392
436
|
|
|
393
|
-
return
|
|
437
|
+
return super().__call__([tensor, num_classes, on_value, off_value, axis])
|
|
394
438
|
|
|
395
439
|
|
|
396
440
|
one_hot_ext_impl = _PyboostOneHotExtPrim()
|
|
397
441
|
|
|
398
442
|
|
|
443
|
+
class _PyboostPromptFlashAttentionPrim(PromptFlashAttentionPrim_):
|
|
444
|
+
def __call__(self, query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise):
|
|
445
|
+
converted_input_layout = str_to_enum('prompt_flash_attention', 'input_layout', input_layout)
|
|
446
|
+
return super().__call__([query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise])
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
prompt_flash_attention_impl = _PyboostPromptFlashAttentionPrim()
|
|
450
|
+
|
|
451
|
+
|
|
399
452
|
class _PyboostReduceAllPrim(ReduceAllPrim_):
|
|
400
453
|
def __call__(self, input, axis, keep_dims):
|
|
401
454
|
|
|
402
|
-
return
|
|
455
|
+
return super().__call__([input, axis, keep_dims])
|
|
403
456
|
|
|
404
457
|
|
|
405
458
|
reduce_all_impl = _PyboostReduceAllPrim()
|
|
@@ -408,16 +461,34 @@ reduce_all_impl = _PyboostReduceAllPrim()
|
|
|
408
461
|
class _PyboostReduceAnyPrim(ReduceAnyPrim_):
|
|
409
462
|
def __call__(self, x, axis, keep_dims):
|
|
410
463
|
|
|
411
|
-
return
|
|
464
|
+
return super().__call__([x, axis, keep_dims])
|
|
412
465
|
|
|
413
466
|
|
|
414
467
|
reduce_any_impl = _PyboostReduceAnyPrim()
|
|
415
468
|
|
|
416
469
|
|
|
470
|
+
class _PyboostReduceMaxPrim(ReduceMaxPrim_):
|
|
471
|
+
def __call__(self, x, axis, keep_dims):
|
|
472
|
+
|
|
473
|
+
return super().__call__([x, axis, keep_dims])
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
reduce_max_impl = _PyboostReduceMaxPrim()
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
class _PyboostReduceMinPrim(ReduceMinPrim_):
|
|
480
|
+
def __call__(self, x, axis, keep_dims):
|
|
481
|
+
|
|
482
|
+
return super().__call__([x, axis, keep_dims])
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
reduce_min_impl = _PyboostReduceMinPrim()
|
|
486
|
+
|
|
487
|
+
|
|
417
488
|
class _PyboostReverseV2Prim(ReverseV2Prim_):
|
|
418
489
|
def __call__(self, input, axis):
|
|
419
490
|
|
|
420
|
-
return
|
|
491
|
+
return super().__call__([input, axis])
|
|
421
492
|
|
|
422
493
|
|
|
423
494
|
reverse_v2_impl = _PyboostReverseV2Prim()
|
|
@@ -426,16 +497,16 @@ reverse_v2_impl = _PyboostReverseV2Prim()
|
|
|
426
497
|
class _PyboostRmsNormPrim(RmsNormPrim_):
|
|
427
498
|
def __call__(self, x, gamma, epsilon):
|
|
428
499
|
|
|
429
|
-
return
|
|
500
|
+
return super().__call__([x, gamma, epsilon])
|
|
430
501
|
|
|
431
502
|
|
|
432
503
|
rms_norm_impl = _PyboostRmsNormPrim()
|
|
433
504
|
|
|
434
505
|
|
|
435
506
|
class _PyboostRollPrim(RollPrim_):
|
|
436
|
-
def __call__(self, input,
|
|
507
|
+
def __call__(self, input, shifts, dims):
|
|
437
508
|
|
|
438
|
-
return
|
|
509
|
+
return super().__call__([input, shifts, dims])
|
|
439
510
|
|
|
440
511
|
|
|
441
512
|
roll_impl = _PyboostRollPrim()
|
|
@@ -444,16 +515,34 @@ roll_impl = _PyboostRollPrim()
|
|
|
444
515
|
class _PyboostSearchSortedPrim(SearchSortedPrim_):
|
|
445
516
|
def __call__(self, sorted_sequence, values, sorter, dtype, right):
|
|
446
517
|
|
|
447
|
-
return
|
|
518
|
+
return super().__call__([sorted_sequence, values, sorter, dtype, right])
|
|
448
519
|
|
|
449
520
|
|
|
450
521
|
searchsorted_impl = _PyboostSearchSortedPrim()
|
|
451
522
|
|
|
452
523
|
|
|
524
|
+
class _PyboostSmoothL1LossGradPrim(SmoothL1LossGradPrim_):
|
|
525
|
+
def __call__(self, prediction, target, dout, beta, reduction):
|
|
526
|
+
converted_reduction = str_to_enum('smooth_l1_loss_grad', 'reduction', reduction)
|
|
527
|
+
return super().__call__([prediction, target, dout, beta, converted_reduction])
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
smooth_l1_loss_grad_impl = _PyboostSmoothL1LossGradPrim()
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class _PyboostSmoothL1LossPrim(SmoothL1LossPrim_):
|
|
534
|
+
def __call__(self, prediction, target, beta, reduction):
|
|
535
|
+
converted_reduction = str_to_enum('smooth_l1_loss', 'reduction', reduction)
|
|
536
|
+
return super().__call__([prediction, target, beta, converted_reduction])
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
smooth_l1_loss_impl = _PyboostSmoothL1LossPrim()
|
|
540
|
+
|
|
541
|
+
|
|
453
542
|
class _PyboostSoftmaxPrim(SoftmaxPrim_):
|
|
454
543
|
def __call__(self, input, axis):
|
|
455
544
|
|
|
456
|
-
return
|
|
545
|
+
return super().__call__([input, axis])
|
|
457
546
|
|
|
458
547
|
|
|
459
548
|
softmax_impl = _PyboostSoftmaxPrim()
|
|
@@ -462,7 +551,7 @@ softmax_impl = _PyboostSoftmaxPrim()
|
|
|
462
551
|
class _PyboostSoftShrinkGradPrim(SoftShrinkGradPrim_):
|
|
463
552
|
def __call__(self, input_grad, input_x, lambd):
|
|
464
553
|
|
|
465
|
-
return
|
|
554
|
+
return super().__call__([input_grad, input_x, lambd])
|
|
466
555
|
|
|
467
556
|
|
|
468
557
|
softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
|
|
@@ -471,43 +560,79 @@ softshrink_grad_impl = _PyboostSoftShrinkGradPrim()
|
|
|
471
560
|
class _PyboostSoftShrinkPrim(SoftShrinkPrim_):
|
|
472
561
|
def __call__(self, input, lambd):
|
|
473
562
|
|
|
474
|
-
return
|
|
563
|
+
return super().__call__([input, lambd])
|
|
475
564
|
|
|
476
565
|
|
|
477
566
|
softshrink_impl = _PyboostSoftShrinkPrim()
|
|
478
567
|
|
|
479
568
|
|
|
569
|
+
class _PyboostSoftMarginLossGradPrim(SoftMarginLossGradPrim_):
|
|
570
|
+
def __call__(self, predict, label, dout, reduction):
|
|
571
|
+
converted_reduction = str_to_enum('soft_margin_loss_grad', 'reduction', reduction)
|
|
572
|
+
return super().__call__([predict, label, dout, converted_reduction])
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
soft_margin_loss_grad_impl = _PyboostSoftMarginLossGradPrim()
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
class _PyboostSoftMarginLossPrim(SoftMarginLossPrim_):
|
|
579
|
+
def __call__(self, input, target, reduction):
|
|
580
|
+
converted_reduction = str_to_enum('soft_margin_loss', 'reduction', reduction)
|
|
581
|
+
return super().__call__([input, target, converted_reduction])
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
soft_margin_loss_impl = _PyboostSoftMarginLossPrim()
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class _PyboostSplitPrim(SplitPrim_):
|
|
588
|
+
def __call__(self, input_x, axis, output_num):
|
|
589
|
+
|
|
590
|
+
return super().__call__([input_x, axis, output_num])
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
split_impl = _PyboostSplitPrim()
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class _PyboostSqueezePrim(SqueezePrim_):
|
|
597
|
+
def __call__(self, input, axis):
|
|
598
|
+
|
|
599
|
+
return super().__call__([input, axis])
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
squeeze_impl = _PyboostSqueezePrim()
|
|
603
|
+
|
|
604
|
+
|
|
480
605
|
class _PyboostStackExtPrim(StackExtPrim_):
|
|
481
606
|
def __call__(self, tensors, dim):
|
|
482
607
|
|
|
483
|
-
return
|
|
608
|
+
return super().__call__([tensors, dim])
|
|
484
609
|
|
|
485
610
|
|
|
486
611
|
stack_ext_impl = _PyboostStackExtPrim()
|
|
487
612
|
|
|
488
613
|
|
|
489
|
-
class
|
|
614
|
+
class _PyboostTriuPrim(TriuPrim_):
|
|
490
615
|
def __call__(self, input, diagonal):
|
|
491
616
|
|
|
492
|
-
return
|
|
617
|
+
return super().__call__([input, diagonal])
|
|
493
618
|
|
|
494
619
|
|
|
495
|
-
|
|
620
|
+
triu_impl = _PyboostTriuPrim()
|
|
496
621
|
|
|
497
622
|
|
|
498
|
-
class
|
|
499
|
-
def __call__(self, input,
|
|
623
|
+
class _PyboostUniqueConsecutivePrim(UniqueConsecutivePrim_):
|
|
624
|
+
def __call__(self, input, return_inverse, return_counts, dim):
|
|
500
625
|
|
|
501
|
-
return
|
|
626
|
+
return super().__call__([input, return_inverse, return_counts, dim])
|
|
502
627
|
|
|
503
628
|
|
|
504
|
-
|
|
629
|
+
unique_consecutive_impl = _PyboostUniqueConsecutivePrim()
|
|
505
630
|
|
|
506
631
|
|
|
507
632
|
class _PyboostUpsampleTrilinear3DGradPrim(UpsampleTrilinear3DGradPrim_):
|
|
508
633
|
def __call__(self, dy, input_size, output_size, scales, align_corners):
|
|
509
634
|
|
|
510
|
-
return
|
|
635
|
+
return super().__call__([dy, input_size, output_size, scales, align_corners])
|
|
511
636
|
|
|
512
637
|
|
|
513
638
|
upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
|
|
@@ -516,16 +641,25 @@ upsample_trilinear3d_grad_impl = _PyboostUpsampleTrilinear3DGradPrim()
|
|
|
516
641
|
class _PyboostUpsampleTrilinear3DPrim(UpsampleTrilinear3DPrim_):
|
|
517
642
|
def __call__(self, x, output_size, scales, align_corners):
|
|
518
643
|
|
|
519
|
-
return
|
|
644
|
+
return super().__call__([x, output_size, scales, align_corners])
|
|
520
645
|
|
|
521
646
|
|
|
522
647
|
upsample_trilinear3d_impl = _PyboostUpsampleTrilinear3DPrim()
|
|
523
648
|
|
|
524
649
|
|
|
650
|
+
class _PyboostFusedInferAttentionScorePrim(FusedInferAttentionScorePrim_):
|
|
651
|
+
def __call__(self, query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode):
|
|
652
|
+
converted_input_layout = str_to_enum('fused_infer_attention_score', 'input_layout', input_layout)
|
|
653
|
+
return super().__call__([query, key, value, pse_shift, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, num_heads, scale_value, pre_tokens, next_tokens, converted_input_layout, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode])
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
fused_infer_attention_score_impl = _PyboostFusedInferAttentionScorePrim()
|
|
657
|
+
|
|
658
|
+
|
|
525
659
|
class _PyboostGroupedMatmulPrim(GroupedMatmulPrim_):
|
|
526
|
-
def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type):
|
|
660
|
+
def __call__(self, x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b):
|
|
527
661
|
|
|
528
|
-
return
|
|
662
|
+
return super().__call__([x, weight, bias, scale, offset, antiquant_scale, antiquant_offset, group_list, split_item, group_type, transpose_a, transpose_b])
|
|
529
663
|
|
|
530
664
|
|
|
531
665
|
grouped_matmul_impl = _PyboostGroupedMatmulPrim()
|
|
@@ -534,7 +668,7 @@ grouped_matmul_impl = _PyboostGroupedMatmulPrim()
|
|
|
534
668
|
class _PyboostQuantBatchMatmulPrim(QuantBatchMatmulPrim_):
|
|
535
669
|
def __call__(self, x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype):
|
|
536
670
|
|
|
537
|
-
return
|
|
671
|
+
return super().__call__([x1, x2, scale, offset, bias, pertokenScaleOptional, transpose_x1, transpose_x2, dtype])
|
|
538
672
|
|
|
539
673
|
|
|
540
674
|
quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
|
|
@@ -543,7 +677,7 @@ quant_batch_matmul_impl = _PyboostQuantBatchMatmulPrim()
|
|
|
543
677
|
class _PyboostWeightQuantBatchMatmulPrim(WeightQuantBatchMatmulPrim_):
|
|
544
678
|
def __call__(self, x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size):
|
|
545
679
|
|
|
546
|
-
return
|
|
680
|
+
return super().__call__([x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, transpose_x, transpose_weight, antiquant_group_size])
|
|
547
681
|
|
|
548
682
|
|
|
549
683
|
weight_quant_batch_matmul_impl = _PyboostWeightQuantBatchMatmulPrim()
|