mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the package's registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the PyboostFunctionsGenerator class for generating C++ functions for PyBoost operations.
|
|
17
|
+
|
|
18
|
+
The generator processes operator prototypes and constructs the necessary function definitions, including
|
|
19
|
+
conversions for optional parameters and tensor arguments. It also generates the registration code and the #include lines for
|
|
20
|
+
the necessary header files for the generated functions.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
|
|
25
|
+
import common.template as template
|
|
26
|
+
import common.gen_constants as K
|
|
27
|
+
from common.template import Template
|
|
28
|
+
from common.gen_utils import save_file
|
|
29
|
+
from common.op_proto import OpProto
|
|
30
|
+
from common.base_generator import BaseGenerator
|
|
31
|
+
from pyboost import pyboost_utils
|
|
32
|
+
from pyboost.pyboost_utils import get_convert_type_str, is_optional_param, is_op_multi_output, get_input_args_type_str, is_tensor_list
|
|
33
|
+
|
|
34
|
+
from .op_template_parser import OpTemplateParser
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class PyboostFunctionsGenerator(BaseGenerator):
|
|
38
|
+
"""
|
|
39
|
+
Generates PyBoost functions based on operator prototypes.
|
|
40
|
+
|
|
41
|
+
This class processes operator prototypes (`op_protos`) to create the necessary C++ function definitions for
|
|
42
|
+
PyBoost operations. It constructs function bodies, handles optional value conversions, and generates
|
|
43
|
+
registration code and header inclusions.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self):
|
|
47
|
+
"""Initializes the PyboostFunctionsGenerator with the necessary templates."""
|
|
48
|
+
self.pyboost_func_include_header_template = Template(
|
|
49
|
+
f'#include "{K.MS_PYBOOST_BASE_PATH}/auto_generate/${{operator_name}}.h"\n'
|
|
50
|
+
)
|
|
51
|
+
self.convert_optional_to_value_template = Template(
|
|
52
|
+
"auto ${output} = PyNativeAlgo::PyBoost::OptionalToValue(${input});\n"
|
|
53
|
+
)
|
|
54
|
+
self.convert_to_tensor_template = Template(
|
|
55
|
+
'auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToTensor(${input}, ${need_contiguous}, '
|
|
56
|
+
'op_run_info->requires_grad);\n'
|
|
57
|
+
)
|
|
58
|
+
self.convert_to_tensor_list_template = Template(
|
|
59
|
+
'auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToValueTuple(${input}, ${need_contiguous}, '
|
|
60
|
+
'op_run_info->requires_grad);\n'
|
|
61
|
+
)
|
|
62
|
+
self.convert_template = Template("auto $arg_name = converter.${convert_func}(args, $arg_index);\n")
|
|
63
|
+
self.input_args_template = Template(" const ${arg_type}& ${arg_name},")
|
|
64
|
+
self.PYBOOST_FUNCTION_TEMPLATE = template.PYBOOST_FUNCTION_TEMPLATE
|
|
65
|
+
self.PYBOOST_COMM_FUNCTION_TEMPLATE = template.PYBOOST_COMM_FUNCTION_TEMPLATE
|
|
66
|
+
self.PYBOOST_FUNCTION_DYNAMIC_OUTPUT_TEMPLATE = template.PYBOOST_FUNCTION_DYNAMIC_OUTPUT_TEMPLATE
|
|
67
|
+
self.REGISTER_DEFINE_TEMPLATE = template.REGISTER_DEFINE_TEMPLATE
|
|
68
|
+
self.REGISTER_TEMPLATE = template.REGISTER_TEMPLATE
|
|
69
|
+
self.PYBOOST_HEADER_TEMPLATE = template.PYBOOST_FUNCTIONS_CC_TEMPLATE
|
|
70
|
+
self.TENSOR_FUNC_CLASS_REG = template.TENSOR_FUNC_CLASS_REG
|
|
71
|
+
self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
|
|
72
|
+
|
|
73
|
+
def generate(self, work_path, op_protos):
|
|
74
|
+
"""
|
|
75
|
+
Generates the C++ PyBoost functions and writes them to the specified files.
|
|
76
|
+
|
|
77
|
+
This method processes a list of operator prototypes (`op_protos`), extracting necessary information
|
|
78
|
+
such as operator names, arguments, and conversion types. It constructs the function definitions, includes,
|
|
79
|
+
and registration code. The generated content is saved to the specified path as a C++ source file.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
work_path (str): The file path where the generated files will be saved.
|
|
83
|
+
op_protos (list): A list of operator prototypes containing information about the operators to be processed.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
None
|
|
87
|
+
"""
|
|
88
|
+
pyboost_func_str = ''
|
|
89
|
+
pyboost_func_pybind_def = ''
|
|
90
|
+
pyboost_func_include_headers_str = ''
|
|
91
|
+
ops_inc_head_set = set()
|
|
92
|
+
for op_proto in op_protos:
|
|
93
|
+
if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
|
|
94
|
+
continue
|
|
95
|
+
op_parser = OpTemplateParser(op_proto)
|
|
96
|
+
op_pyboost_func_name = op_parser.get_pyboost_func_name()
|
|
97
|
+
op_def_name_str = op_parser.get_op_def_name_str()
|
|
98
|
+
type_num, same_type = op_parser.gen_signature_same_type_table()
|
|
99
|
+
parser_body_str = self._generate_parser_func(op_proto)
|
|
100
|
+
op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
|
|
101
|
+
convert_stub_str = self._get_convert_stub_str(op_proto)
|
|
102
|
+
optional_to_value_str = self._get_optional_to_value_str(op_proto)
|
|
103
|
+
call_args_str = self._get_call_args_str(op_proto)
|
|
104
|
+
grad_args_str = self._get_grad_args_str(op_proto)
|
|
105
|
+
cast_args_str = self._get_cast_to_value_str(op_proto)
|
|
106
|
+
view_arg_str = self._get_first_str(op_proto.op_view, grad_args_str)
|
|
107
|
+
op_input_args_str = self._get_input_args_str(op_proto)
|
|
108
|
+
view_arg_str = ", " + view_arg_str if view_arg_str else ''
|
|
109
|
+
multi_ouptut_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
|
|
110
|
+
output_num_str = len(op_proto.op_returns)
|
|
111
|
+
function_tpl = self._get_function_tpl(op_proto)
|
|
112
|
+
pyboost_func_str += function_tpl.replace(func_name=op_pyboost_func_name,
|
|
113
|
+
op_def_name=op_def_name_str,
|
|
114
|
+
type_num=type_num,
|
|
115
|
+
same_type=same_type,
|
|
116
|
+
input_args=op_input_args_str,
|
|
117
|
+
parser_body=parser_body_str,
|
|
118
|
+
op_name=op_proto.op_class.name,
|
|
119
|
+
class_name=op_proto.op_class.name,
|
|
120
|
+
op_args=op_args_str,
|
|
121
|
+
convert_stub=convert_stub_str,
|
|
122
|
+
optional_to_value=optional_to_value_str,
|
|
123
|
+
call_args=call_args_str,
|
|
124
|
+
grad_args=grad_args_str,
|
|
125
|
+
cast_args=cast_args_str,
|
|
126
|
+
view_arg=view_arg_str,
|
|
127
|
+
is_multi=multi_ouptut_str,
|
|
128
|
+
output_num=output_num_str,
|
|
129
|
+
operator_name=op_proto.op_name)
|
|
130
|
+
pyboost_func_str = pyboost_func_str + template.NEW_LINE + template.NEW_LINE
|
|
131
|
+
pyboost_op_name = op_parser.get_pyboost_name()
|
|
132
|
+
pyboost_func_name = op_parser.get_pyboost_func_name()
|
|
133
|
+
pyboost_func_pybind_def += self.REGISTER_DEFINE_TEMPLATE.replace(
|
|
134
|
+
pyboost_op_name=pyboost_op_name,
|
|
135
|
+
pyboost_cfunc_name=pyboost_func_name,
|
|
136
|
+
class_name=op_proto.op_class.name)
|
|
137
|
+
pyboost_func_include_headers_str += self.pyboost_func_include_header_template.replace(
|
|
138
|
+
operator_name=op_proto.op_name)
|
|
139
|
+
ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
|
|
140
|
+
register_func_str = self.REGISTER_TEMPLATE.replace(register_func=pyboost_func_pybind_def)
|
|
141
|
+
function_class_register = self._get_function_class_register(op_protos)
|
|
142
|
+
pyboost_func_file = self.PYBOOST_HEADER_TEMPLATE.replace(ops_inc=list(sorted(ops_inc_head_set)),
|
|
143
|
+
include_op_header=pyboost_func_include_headers_str,
|
|
144
|
+
function_body=pyboost_func_str,
|
|
145
|
+
register_function_body=register_func_str,
|
|
146
|
+
function_class_register=function_class_register)
|
|
147
|
+
save_path = os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)
|
|
148
|
+
file_name = "pyboost_functions.cc"
|
|
149
|
+
save_file(save_path, file_name, pyboost_func_file)
|
|
150
|
+
|
|
151
|
+
def _get_cast_args_with_type_str(self, op_proto, cast_args_str):
|
|
152
|
+
args_with_type = []
|
|
153
|
+
for op_arg, cast_args_name in zip(op_proto.op_args, cast_args_str):
|
|
154
|
+
input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional_param(op_arg))
|
|
155
|
+
args_with_type.append("const " + input_dtype + " &" + cast_args_name)
|
|
156
|
+
return list(args_with_type)
|
|
157
|
+
|
|
158
|
+
def _get_function_class_register(self, op_protos) -> str:
|
|
159
|
+
"""
|
|
160
|
+
Generates a function class registration string for tensor functions.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
op_protos (list): A list of tensor op prototypes.
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
str: A concatenated string representing the registration information for tensor
|
|
167
|
+
function classes.
|
|
168
|
+
"""
|
|
169
|
+
function_class_register = ''
|
|
170
|
+
for op_proto in op_protos:
|
|
171
|
+
if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
|
|
172
|
+
continue
|
|
173
|
+
class_name, op_name = op_proto.op_class.name, op_proto.op_name
|
|
174
|
+
function_class_register += self.TENSOR_FUNC_CLASS_REG.replace(class_name=class_name,
|
|
175
|
+
op_name=op_name)
|
|
176
|
+
return function_class_register
|
|
177
|
+
|
|
178
|
+
def _generate_parser_func(self, op_proto: OpProto) -> str:
|
|
179
|
+
"""
|
|
180
|
+
Generates the parsing function for the operator's arguments.
|
|
181
|
+
|
|
182
|
+
This method constructs the code for converting each argument in the operator prototype to its appropriate
|
|
183
|
+
type, handling optional parameters as necessary.
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
str: The generated parsing function code as a string.
|
|
190
|
+
"""
|
|
191
|
+
parser_func_str = ''
|
|
192
|
+
for index, op_arg in enumerate(op_proto.op_args):
|
|
193
|
+
is_optional = is_optional_param(op_arg)
|
|
194
|
+
if op_arg.is_type_id:
|
|
195
|
+
convert_type_str = get_convert_type_str('type', is_optional)
|
|
196
|
+
else:
|
|
197
|
+
convert_type_str = get_convert_type_str(op_arg.arg_dtype, is_optional)
|
|
198
|
+
|
|
199
|
+
parser_func_str += self.convert_template.replace(arg_name=op_arg.arg_name, convert_func=convert_type_str,
|
|
200
|
+
arg_index=pyboost_utils.get_index(index))
|
|
201
|
+
return parser_func_str
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _get_input_args_str(self, op_proto: OpProto) -> str:
|
|
205
|
+
"""
|
|
206
|
+
Generates the input arguments list for the pyboost operator.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
210
|
+
|
|
211
|
+
Returns:
|
|
212
|
+
str: The generated input arguments list as a string.
|
|
213
|
+
"""
|
|
214
|
+
parser_func_str = ''
|
|
215
|
+
for _, op_arg in enumerate(op_proto.op_args):
|
|
216
|
+
is_optional = is_optional_param(op_arg)
|
|
217
|
+
if op_arg.is_type_id:
|
|
218
|
+
arg_type_str = get_input_args_type_str('type', is_optional)
|
|
219
|
+
else:
|
|
220
|
+
arg_type_str = get_input_args_type_str(op_arg.arg_dtype, is_optional)
|
|
221
|
+
parser_func_str += self.input_args_template.replace(arg_name=op_arg.arg_name, arg_type=arg_type_str)
|
|
222
|
+
return parser_func_str[:-1]
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _get_convert_stub_str(self, op_proto: OpProto):
|
|
226
|
+
"""
|
|
227
|
+
Generates the conversion stub code for the operator's arguments.
|
|
228
|
+
|
|
229
|
+
This method creates code for converting operator arguments to tensor format, depending on whether they
|
|
230
|
+
are view operations or standard tensor operations.
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
str: The generated conversion stub code as a string.
|
|
237
|
+
"""
|
|
238
|
+
convert_stub_str = ''
|
|
239
|
+
need_contiguous = 'true'
|
|
240
|
+
if op_proto.op_view:
|
|
241
|
+
# View/ACLNN op does not need to convert to contiguous tensor.
|
|
242
|
+
need_contiguous = 'false'
|
|
243
|
+
for op_arg in op_proto.op_args:
|
|
244
|
+
if pyboost_utils.is_tensor(op_arg):
|
|
245
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
246
|
+
else op_arg.arg_name + "_tensor"
|
|
247
|
+
convert_stub_str += self.convert_to_tensor_template.replace(input=op_arg.arg_name,
|
|
248
|
+
output=convert_stub_output_name,
|
|
249
|
+
need_contiguous=need_contiguous)
|
|
250
|
+
elif pyboost_utils.is_tensor_list(op_arg):
|
|
251
|
+
# To adapt the cases where TensorList is optional.
|
|
252
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
253
|
+
else op_arg.arg_name + "_tensor_list"
|
|
254
|
+
convert_stub_str += self.convert_to_tensor_list_template.replace(input=op_arg.arg_name,
|
|
255
|
+
output=convert_stub_output_name,
|
|
256
|
+
need_contiguous=need_contiguous)
|
|
257
|
+
return convert_stub_str
|
|
258
|
+
|
|
259
|
+
def _get_optional_to_value_str(self, op_proto: OpProto):
|
|
260
|
+
"""
|
|
261
|
+
Generates the code for converting optional arguments to their corresponding values.
|
|
262
|
+
|
|
263
|
+
This method constructs code to handle optional arguments and converts them to their actual values,
|
|
264
|
+
ensuring proper handling for tensors and lists.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
268
|
+
|
|
269
|
+
Returns:
|
|
270
|
+
str: The generated code for converting optional arguments to values as a string.
|
|
271
|
+
"""
|
|
272
|
+
optional_to_value_str = ''
|
|
273
|
+
for op_arg in op_proto.op_args:
|
|
274
|
+
if is_optional_param(op_arg):
|
|
275
|
+
if pyboost_utils.is_tensor(op_arg) or pyboost_utils.is_tensor_list(op_arg):
|
|
276
|
+
convert_stub_output_name = op_arg.arg_name + '_optional'
|
|
277
|
+
cast_output = 'cast_' + convert_stub_output_name
|
|
278
|
+
convert_optional_to_value_name = op_arg.arg_name + '_value'
|
|
279
|
+
optional_to_value_str += \
|
|
280
|
+
self.convert_optional_to_value_template.replace(input=cast_output,
|
|
281
|
+
output=convert_optional_to_value_name)
|
|
282
|
+
else:
|
|
283
|
+
call_arg = op_arg.arg_name
|
|
284
|
+
convert_optional_to_value_name = op_arg.arg_name + '_value'
|
|
285
|
+
optional_to_value_str += \
|
|
286
|
+
self.convert_optional_to_value_template.replace(input=call_arg,
|
|
287
|
+
output=convert_optional_to_value_name)
|
|
288
|
+
return optional_to_value_str
|
|
289
|
+
|
|
290
|
+
def _get_call_args_str(self, op_proto: OpProto):
|
|
291
|
+
"""
|
|
292
|
+
Generates the list of call arguments for the operator.
|
|
293
|
+
|
|
294
|
+
This method constructs a list of argument names for the function call, adapting the names for
|
|
295
|
+
optional tensors and tensor lists as needed.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
list: A list of formatted argument names for the function call.
|
|
302
|
+
"""
|
|
303
|
+
call_args_str = []
|
|
304
|
+
for op_arg in op_proto.op_args:
|
|
305
|
+
if pyboost_utils.is_tensor(op_arg):
|
|
306
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
307
|
+
else op_arg.arg_name + "_tensor"
|
|
308
|
+
call_arg = convert_stub_output_name
|
|
309
|
+
elif pyboost_utils.is_tensor_list(op_arg):
|
|
310
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
311
|
+
else op_arg.arg_name + "_tensor_list"
|
|
312
|
+
call_arg = convert_stub_output_name
|
|
313
|
+
else:
|
|
314
|
+
call_arg = op_arg.arg_name
|
|
315
|
+
call_args_str.append(call_arg)
|
|
316
|
+
return call_args_str
|
|
317
|
+
|
|
318
|
+
def _get_grad_args_str(self, op_proto: OpProto):
|
|
319
|
+
"""
|
|
320
|
+
Generates the list of gradient arguments for the operator.
|
|
321
|
+
|
|
322
|
+
This method constructs a list of argument names used for computing gradients, adapting for
|
|
323
|
+
optional tensors and tensor lists as necessary.
|
|
324
|
+
|
|
325
|
+
Args:
|
|
326
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
327
|
+
|
|
328
|
+
Returns:
|
|
329
|
+
list: A list of formatted gradient argument names.
|
|
330
|
+
"""
|
|
331
|
+
grad_args_str = []
|
|
332
|
+
for op_arg in op_proto.op_args:
|
|
333
|
+
if pyboost_utils.is_tensor(op_arg):
|
|
334
|
+
grad_arg = op_arg.arg_name + "_value" if is_optional_param(op_arg) else \
|
|
335
|
+
f"cast_" + op_arg.arg_name + "_tensor"
|
|
336
|
+
elif pyboost_utils.is_tensor_list(op_arg):
|
|
337
|
+
if is_optional_param(op_arg):
|
|
338
|
+
# To adapt the cases where TensorList is optional.
|
|
339
|
+
convert_optional_to_value_name = op_arg.arg_name + "_value"
|
|
340
|
+
grad_arg = convert_optional_to_value_name
|
|
341
|
+
else:
|
|
342
|
+
convert_stub_output_name = op_arg.arg_name + "_tensor_list"
|
|
343
|
+
grad_arg = "cast_" + convert_stub_output_name
|
|
344
|
+
else:
|
|
345
|
+
grad_arg = "cast_" + op_arg.arg_name
|
|
346
|
+
if is_optional_param(op_arg):
|
|
347
|
+
convert_optional_to_value_name = op_arg.arg_name + "_value"
|
|
348
|
+
grad_arg = convert_optional_to_value_name
|
|
349
|
+
grad_args_str.append(grad_arg)
|
|
350
|
+
return grad_args_str
|
|
351
|
+
|
|
352
|
+
def _get_cast_to_value_str(self, op_proto: OpProto):
|
|
353
|
+
"""
|
|
354
|
+
Generates the list of cast arguments for the operator.
|
|
355
|
+
|
|
356
|
+
This method constructs a list of argument names that need to be cast to their corresponding types.
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
360
|
+
|
|
361
|
+
Returns:
|
|
362
|
+
list: A list of formatted cast argument names.
|
|
363
|
+
"""
|
|
364
|
+
cast_args_str = []
|
|
365
|
+
for op_arg in op_proto.op_args:
|
|
366
|
+
cast_str = 'cast_'
|
|
367
|
+
if pyboost_utils.is_tensor(op_arg):
|
|
368
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
369
|
+
else op_arg.arg_name + "_tensor"
|
|
370
|
+
cast_arg = cast_str + convert_stub_output_name
|
|
371
|
+
elif pyboost_utils.is_tensor_list(op_arg):
|
|
372
|
+
# To adapt the cases where TensorList is optional.
|
|
373
|
+
convert_stub_output_name = op_arg.arg_name + '_optional' if is_optional_param(op_arg) \
|
|
374
|
+
else op_arg.arg_name + "_tensor_list"
|
|
375
|
+
cast_arg = cast_str + convert_stub_output_name
|
|
376
|
+
else:
|
|
377
|
+
cast_arg = cast_str + op_arg.arg_name
|
|
378
|
+
cast_args_str.append(cast_arg)
|
|
379
|
+
return cast_args_str
|
|
380
|
+
|
|
381
|
+
def _get_first_str(self, is_view_or_inplace: bool, grad_args: list):
|
|
382
|
+
"""
|
|
383
|
+
Generates the view base str of arguments for the operator.
|
|
384
|
+
|
|
385
|
+
This method constructs a list of argument names that need to be cast to their corresponding types.
|
|
386
|
+
|
|
387
|
+
Args:
|
|
388
|
+
is_view_or_inplace (bool): Whether the op is view op or inplace op.
|
|
389
|
+
grad_args (list): grad args
|
|
390
|
+
|
|
391
|
+
Returns:
|
|
392
|
+
str: Formatted view or inplace first argument names.
|
|
393
|
+
"""
|
|
394
|
+
arg_str = ''
|
|
395
|
+
for i, grad_arg in enumerate(grad_args):
|
|
396
|
+
if is_view_or_inplace and i == 0:
|
|
397
|
+
arg_str = grad_arg
|
|
398
|
+
break
|
|
399
|
+
return arg_str
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def _get_function_tpl(self, op_proto: OpProto):
|
|
403
|
+
if len(op_proto.op_returns) == 1 and is_tensor_list(op_proto.op_returns[0]):
|
|
404
|
+
# op output size is unknown
|
|
405
|
+
return self.PYBOOST_FUNCTION_DYNAMIC_OUTPUT_TEMPLATE
|
|
406
|
+
return self.PYBOOST_COMM_FUNCTION_TEMPLATE \
|
|
407
|
+
if op_proto.op_dispatch.is_comm_op else self.PYBOOST_FUNCTION_TEMPLATE
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the `PyboostFunctionsHeaderGenerator` class, which is responsible for generating
|
|
17
|
+
the header file (`pyboost_functions.h`) for Pyboost function declarations.
|
|
18
|
+
|
|
19
|
+
The class uses templates and operation prototypes to create function declarations based on the
|
|
20
|
+
operation's primitive and arguments. The generated file is saved to the specified path.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
|
|
25
|
+
import common.template as template
|
|
26
|
+
import common.gen_constants as K
|
|
27
|
+
from common.template import Template
|
|
28
|
+
from common.gen_utils import save_file
|
|
29
|
+
from common.base_generator import BaseGenerator
|
|
30
|
+
from common.op_proto import OpProto
|
|
31
|
+
|
|
32
|
+
from .op_template_parser import OpTemplateParser
|
|
33
|
+
from .pyboost_utils import is_optional_param, get_input_args_type_str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PyboostFunctionsHeaderGenerator(BaseGenerator):
|
|
37
|
+
"""
|
|
38
|
+
A class to generate the `pyboost_functions.h` header file, which contains Pyboost function declarations.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self):
|
|
42
|
+
"""Initializes the PyboostFunctionsHeaderGenerator with the necessary templates."""
|
|
43
|
+
self.PYBOOST_FUNCTION_HEADER_TEMPLATE = template.PYBOOST_FUNCTION_HEADER_TEMPLATE
|
|
44
|
+
|
|
45
|
+
self.pyboost_func_template = Template(
|
|
46
|
+
'py::object PYNATIVE_EXPORT ${func_name}_Base(const PrimitivePtr &prim, const py::list &args);'
|
|
47
|
+
)
|
|
48
|
+
self.pyboost_op_func_template = Template(
|
|
49
|
+
'py::object ME_EXPORT ${func_name}_OP(const PrimitivePtr &prim, '
|
|
50
|
+
'const std::vector<ops::OP_DTYPE>& source_type, ${input_args});'
|
|
51
|
+
)
|
|
52
|
+
self.input_args_template = Template(" const ${arg_type}& ${arg_name},")
|
|
53
|
+
|
|
54
|
+
def generate(self, work_path, op_protos):
|
|
55
|
+
"""
|
|
56
|
+
Generates the Pyboost function header file (`pyboost_functions.h`).
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
work_path (str): The directory where the generated file will be saved.
|
|
60
|
+
op_protos (list): A list of operation prototypes to parse and convert into Pyboost function declarations.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
None: The method writes the generated header file to the specified directory.
|
|
64
|
+
"""
|
|
65
|
+
prim_func_list = []
|
|
66
|
+
op_func_list_str = []
|
|
67
|
+
for op_proto in op_protos:
|
|
68
|
+
if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
|
|
69
|
+
continue
|
|
70
|
+
op_parser = OpTemplateParser(op_proto)
|
|
71
|
+
op_pyboost_func_name = op_parser.get_pyboost_func_name()
|
|
72
|
+
op_input_args_str = self._get_input_args_str(op_proto)
|
|
73
|
+
prim_func_list.append(self.pyboost_func_template.replace(func_name=op_pyboost_func_name))
|
|
74
|
+
op_func_list_str.append(self.pyboost_op_func_template.replace(func_name=op_pyboost_func_name,
|
|
75
|
+
input_args=op_input_args_str))
|
|
76
|
+
pyboost_func_h_str = self.PYBOOST_FUNCTION_HEADER_TEMPLATE.replace(prim_func_list=prim_func_list,
|
|
77
|
+
op_func_list=op_func_list_str)
|
|
78
|
+
save_path = os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)
|
|
79
|
+
file_name = "pyboost_functions.h"
|
|
80
|
+
save_file(save_path, file_name, pyboost_func_h_str)
|
|
81
|
+
|
|
82
|
+
def _get_input_args_str(self, op_proto: OpProto) -> str:
|
|
83
|
+
"""
|
|
84
|
+
Generates the input arguments list for the pyboost operator.
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
op_proto (OpProto): The operator prototype containing the argument information.
|
|
88
|
+
|
|
89
|
+
Returns:
|
|
90
|
+
str: The generated input arguments list as a string.
|
|
91
|
+
"""
|
|
92
|
+
parser_func_str = ''
|
|
93
|
+
for _, op_arg in enumerate(op_proto.op_args):
|
|
94
|
+
is_optional = is_optional_param(op_arg)
|
|
95
|
+
if op_arg.is_type_id:
|
|
96
|
+
arg_type_str = get_input_args_type_str('type', is_optional)
|
|
97
|
+
else:
|
|
98
|
+
arg_type_str = get_input_args_type_str(op_arg.arg_dtype, is_optional)
|
|
99
|
+
parser_func_str += self.input_args_template.replace(arg_name=op_arg.arg_name, arg_type=arg_type_str)
|
|
100
|
+
return parser_func_str[:-1]
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the PyboostFunctionsPyGenerator class for generating Python bindings for PyBoost functions.
|
|
17
|
+
|
|
18
|
+
The PyboostFunctionsPyGenerator class processes operator prototypes and generates Python functions
|
|
19
|
+
that correspond to the PyBoost operations defined in the operator prototypes. It handles the necessary
|
|
20
|
+
argument processing and includes appropriate documentation descriptions.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
|
|
25
|
+
import common.template as template
|
|
26
|
+
import common.gen_constants as K
|
|
27
|
+
from common.gen_utils import save_file
|
|
28
|
+
from common.op_proto import OpProto
|
|
29
|
+
from common.base_generator import BaseGenerator
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PyboostFunctionsPyGenerator(BaseGenerator):
|
|
33
|
+
"""
|
|
34
|
+
Generates Python bindings for PyBoost functions.
|
|
35
|
+
|
|
36
|
+
This class is responsible for creating Python function definitions that correspond to the PyBoost
|
|
37
|
+
operations defined in operator prototypes. It generates a Python file that includes necessary function
|
|
38
|
+
definitions and their descriptions.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self):
|
|
42
|
+
"""Initializes the PyboostFunctionsPyGenerator with required templates."""
|
|
43
|
+
self.IMPORT_PYBOOST_FUNC_HEADER = template.IMPORT_PYBOOST_FUNC_HEADER
|
|
44
|
+
self.PYBOOST_PY_FUNC_TEMPLATE = template.PYBOOST_PY_FUNC_TEMPLATE
|
|
45
|
+
|
|
46
|
+
def generate(self, work_path, op_protos, doc_data):
|
|
47
|
+
"""
|
|
48
|
+
Generates the Python file containing PyBoost function definitions.
|
|
49
|
+
|
|
50
|
+
This method processes the provided operator prototypes (`op_protos`), generates Python function
|
|
51
|
+
definitions for each operator that meets the specified conditions, and saves the generated content
|
|
52
|
+
to a Python file.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
work_path (str): The directory path where the generated Python file will be saved.
|
|
56
|
+
op_protos (list): A list of operator prototypes containing information about the operators.
|
|
57
|
+
doc_data (dict): A dictionary containing documentation data for the operators.
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
None
|
|
61
|
+
"""
|
|
62
|
+
gen_py = ''
|
|
63
|
+
op_desc_dict = self._get_op_description_dict(doc_data)
|
|
64
|
+
for op_proto in op_protos:
|
|
65
|
+
# check if the operator is in pyboost scenario
|
|
66
|
+
if op_proto.op_dispatch is None or not op_proto.op_dispatch.enable:
|
|
67
|
+
continue
|
|
68
|
+
if op_proto.op_function.disable:
|
|
69
|
+
continue
|
|
70
|
+
if not op_proto.op_function.name.endswith("_ext") and not op_proto.op_name.endswith("_ext"):
|
|
71
|
+
continue
|
|
72
|
+
|
|
73
|
+
description = op_desc_dict.get(op_proto.op_name)
|
|
74
|
+
func_args, input_args = self._process_args(op_proto.op_args)
|
|
75
|
+
func_name, func_impl_name = self._get_func_impl_name(op_proto)
|
|
76
|
+
gen_py += self.PYBOOST_PY_FUNC_TEMPLATE.replace(func_name=func_name,
|
|
77
|
+
description=description,
|
|
78
|
+
func_args=func_args,
|
|
79
|
+
input_args=input_args,
|
|
80
|
+
func_impl_name=func_impl_name)
|
|
81
|
+
py_header = template.PY_LICENSE_STR + self.IMPORT_PYBOOST_FUNC_HEADER
|
|
82
|
+
save_file(os.path.join(work_path, K.PY_AUTO_GEN_PATH), "gen_extend_func.py", py_header + gen_py)
|
|
83
|
+
|
|
84
|
+
def _get_op_description_dict(self, doc_yaml_data):
|
|
85
|
+
"""
|
|
86
|
+
Constructs a dictionary mapping operator names to their descriptions.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
doc_yaml_data (dict): A dictionary containing YAML data for operator documentation.
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
dict: A dictionary mapping operator names to their descriptions.
|
|
93
|
+
"""
|
|
94
|
+
op_description_dict = {}
|
|
95
|
+
for operator_name, operator_desc in doc_yaml_data.items():
|
|
96
|
+
desc = operator_desc.get("description")
|
|
97
|
+
op_description_dict[operator_name] = desc
|
|
98
|
+
return op_description_dict
|
|
99
|
+
|
|
100
|
+
def _process_args(self, op_args):
|
|
101
|
+
"""
|
|
102
|
+
Processes the operator arguments to generate function argument strings.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
op_args (list): A list of operator arguments to be processed.
|
|
106
|
+
|
|
107
|
+
Returns:
|
|
108
|
+
tuple: A tuple containing:
|
|
109
|
+
- func_args (list): A list of formatted function argument strings.
|
|
110
|
+
- input_args (list): A list of corresponding input argument names.
|
|
111
|
+
"""
|
|
112
|
+
func_args = []
|
|
113
|
+
input_args = []
|
|
114
|
+
for op_arg in op_args:
|
|
115
|
+
arg_handler = op_arg.arg_handler
|
|
116
|
+
arg_name = op_arg.arg_name
|
|
117
|
+
input_arg = arg_name
|
|
118
|
+
if arg_handler != '' and arg_handler != 'dtype_to_type_id':
|
|
119
|
+
input_arg = 'converted_' + arg_name
|
|
120
|
+
input_args.append(input_arg)
|
|
121
|
+
default_value = op_arg.default
|
|
122
|
+
if default_value is not None:
|
|
123
|
+
default_value = '=' + str(default_value)
|
|
124
|
+
func_args.append(arg_name + default_value)
|
|
125
|
+
else:
|
|
126
|
+
func_args.append(arg_name)
|
|
127
|
+
return func_args, input_args
|
|
128
|
+
|
|
129
|
+
def _get_func_impl_name(self, op_proto: OpProto):
|
|
130
|
+
"""
|
|
131
|
+
Retrieves the implementation function name based on the operator prototype.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
op_proto (OpProto): The operator prototype containing function name information.
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
tuple: A tuple containing:
|
|
138
|
+
- func_name (str): The name of the function.
|
|
139
|
+
- func_impl_name (str): The implementation name of the function.
|
|
140
|
+
"""
|
|
141
|
+
func_name = op_proto.op_name if op_proto.op_function.name == '' \
|
|
142
|
+
else op_proto.op_function.name
|
|
143
|
+
if func_name.endswith("_ext"):
|
|
144
|
+
func_name = func_name[:-4]
|
|
145
|
+
func_impl_name = func_name
|
|
146
|
+
if func_name.endswith("_"):
|
|
147
|
+
func_impl_name = func_name[:-1]
|
|
148
|
+
return func_name, func_impl_name
|