mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the PyboostOverloadFunctionsGenerator class for generating C++ overload functions for PyBoost operations.
|
|
17
|
+
|
|
18
|
+
The generator processes operator prototypes and constructs the necessary function definitions, including
|
|
19
|
+
conversions for optional parameters and tensor arguments. It generates the registration code and includes
|
|
20
|
+
the necessary header files for the generated functions.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
|
|
25
|
+
import common.template as template
|
|
26
|
+
import common.gen_constants as K
|
|
27
|
+
from common.template import Template
|
|
28
|
+
from common.gen_utils import save_file
|
|
29
|
+
from common.op_proto import OpProto
|
|
30
|
+
from common.base_generator import BaseGenerator
|
|
31
|
+
from pyboost import pyboost_utils
|
|
32
|
+
import api.op_api_proto as op_api_proto
|
|
33
|
+
|
|
34
|
+
from .op_template_parser import OpTemplateParser
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class PyboostOverloadFunctionsGenerator(BaseGenerator):
    """
    Generates PyBoost overload functions cpp code based on operator prototypes.

    For every mint functional API this class emits a C++ functional class definition: one per API with a
    single definition and one per overloaded API, whose call body switches over the matched signature. It
    also emits the pybind registration code and the op-def header includes, and writes everything to
    ``pyboost_overload_functions.cc``.
    """

    def __init__(self):
        # File-level and per-class C++ templates provided by common.template.
        self.PYBOOST_OVERLOAD_FUNCTIONS_TEMPLATE = template.PYBOOST_OVERLOAD_FUNCTIONS_CC_TEMPLATE
        self.PYBOOST_MINT_CLASS_DEF = template.PYBOOST_MINT_CLASS_DEF
        self.PYBOOST_OVERLOAD_MINT_CLASS_DEF = template.PYBOOST_OVERLOAD_MINT_CLASS_DEF
        self.TENSOR_FUNC_UT_BODY = template.TENSOR_FUNC_UT_BODY
        self.TENSOR_FUNC_UT_OVERLOAD_BODY = template.TENSOR_FUNC_UT_OVERLOAD_BODY

        # One `case` of the overload switch: the per-device dispatcher followed by `break`.
        self.single_case_template = Template(
            'case ${case_id}:\n'
            ' ${device_dispatcher}\n'
            ' break;\n'
        )
        # Chooses the dispatcher code by backend; any other backend logs an error and returns None.
        self.device_dispatcher_template = Template(
            'if (backend == kAscendDevice || backend == kDavinciDevice) {\n'
            ' ${ascend_dispatcher}\n'
            '} else if (backend == kCPUDevice) {\n'
            ' ${cpu_dispatcher}\n'
            '} else if (backend == kGPUDevice) {\n'
            ' ${gpu_dispatcher}\n'
            '} else {\n'
            ' MS_LOG(ERROR) << "Device target is not supported!";\n'
            ' return py::none();\n'
            '}'
        )
        # Calls `<pyboost func>_OP`, records the call with trace::Capture and returns its result.
        self.pyboost_return_template = Template(
            '${arg_handler_processor}\n'
            'MS_LOG(INFO) << "Call Tensor${class_name}";\n'
            'auto res = ${pyboost_base_func_name}_OP(${prim_name}, parse_args.src_types_, ${convert_args});\n'
            'trace::Capture(parse_args.arg_list_, "${class_name}", &res);\n'
            'return res;\n'
        )
        # Falls back to the Python implementation looked up in mindspore.ops.tensor_method.
        self.callback_python_template = Template(
            'MS_LOG(INFO) << "Callback python method: ${py_method}";\n'
            'py::function fn = python_adapter::GetPyFn(\"mindspore.ops.tensor_method\", \"${py_method}\");\n'
            'py::object res = fn(*args, **kwargs);\n'
            'return res;\n'
        )
        # Binds `<Cpp>Functional` to Python and exposes `<mint_name>_instance` as module attr `_<mint_name>_instance`.
        self.pybind_register_template = Template(
            '(void)py::class_<${cpp_func_name}Functional, Functional, std::shared_ptr<${cpp_func_name}Functional>>\n'
            ' (*m, "${cpp_func_name}Functional_")\n'
            ' .def("__call__", &${cpp_func_name}Functional::Call, "Call ${cpp_func_name} functional.");\n'
            'm->attr("_${mint_func_name}_instance") = ${mint_func_name}_instance;'
        )
        # UT variant of the python callback: assigns pre-declared `fn`/`res` and leaves the switch itself.
        self.callback_python_in_ut_template = Template(
            'MS_LOG(INFO) << "Callback python method in UT: ${py_method}";\n'
            'fn = python_adapter::GetPyFn(\"mindspore.ops.tensor_method\", \"${py_method}\");\n'
            'res = fn(*args, **kwargs);\n'
            'break;\n'
        )
        # UT variant of a `case`; the `break` comes from callback_python_in_ut_template.
        self.single_case_in_ut_template = Template(
            'case ${case_id}:\n'
            ' ${device_dispatcher}\n'
        )

    def generate(self, work_path, op_protos, mint_func_protos_data, alias_func_mapping):
        """
        Generates the C++ PyBoost overload functions and writes them to pyboost_overload_functions.cc.

        The mint function prototypes are split into single-definition and overloaded APIs. A functional
        class definition and a pybind registration are generated for each, and an op-def include is
        collected for every op class involved.

        Args:
            work_path (str): Root path; the file is written under `K.PIPELINE_PYBOOST_FUNC_GEN_PATH`.
            op_protos (list): Operator prototypes. Not read by this method.
            mint_func_protos_data (dict): Mint function prototypes, passed to `categorize_func_data`.
            alias_func_mapping (dict): Mapping from api name to its alias api name. Not read by this method.

        Returns:
            None
        """
        _, single_mint_func_data, overload_mint_func_data, op_class_name_set = op_api_proto.categorize_func_data(
            mint_func_protos_data)
        single_func_call_body_list, single_cpp_class_name_list = (
            self._get_single_func_call_body_list(single_mint_func_data))
        overload_func_call_body_list, overload_cpp_class_name_list = (
            self._get_overload_func_call_body_list(overload_mint_func_data))

        mint_classes_def_list = single_func_call_body_list + overload_func_call_body_list

        # Single-definition names first, then overloaded ones: _get_mint_func_reg_list pairs them
        # positionally with the mint function names in the same order.
        cpp_class_name_list = single_cpp_class_name_list + overload_cpp_class_name_list
        mint_classes_reg_list = self._get_mint_func_reg_list(
            single_mint_func_data, overload_mint_func_data, cpp_class_name_list)

        # Deduplicated include lines, keyed by the lower-cased first letter of each op class name.
        ops_inc_head_set = {
            template.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_class_name[0].lower())
            for op_class_name in op_class_name_set
        }
        pyboost_overload_file_str = self.PYBOOST_OVERLOAD_FUNCTIONS_TEMPLATE.replace(
            ops_inc=sorted(ops_inc_head_set),  # sorted for a deterministic generated file
            mint_func_classes_def=mint_classes_def_list,
            pybind_register_code=mint_classes_reg_list)
        save_path = os.path.join(work_path, K.PIPELINE_PYBOOST_FUNC_GEN_PATH)
        file_name = "pyboost_overload_functions.cc"
        save_file(save_path, file_name, pyboost_overload_file_str)

    def _get_single_func_call_body_list(self, single_op_func_data):
        """
        Generates the list of call body strings for single operation functions.

        Args:
            single_op_func_data (dict): Dictionary of tensor function prototypes with only one definition.

        Returns:
            func_call_body_list (list): Str list for generating C++ function call bodies.
            cpp_class_name_list (list): The list of non-overloaded c++ functional classes' names
                (the op class names), in the same order as `single_op_func_data`.
        """
        func_call_body_list, cpp_class_name_list = [], []
        for func_proto in single_op_func_data.values():
            class_name = func_proto.op_proto.op_class.name
            device_dispatcher_str = self._get_device_dispatchers_str(func_proto)
            signature_str = self._generate_single_signature_str(
                func_proto.op_proto, func_proto.kw_only_args, func_proto.varargs)
            ut_body = self.TENSOR_FUNC_UT_BODY.replace(py_method=func_proto.py_method)
            func_call_body_list.append(self.PYBOOST_MINT_CLASS_DEF.replace(
                class_name=class_name,
                func_name=func_proto.func_name,
                device_dispatcher=device_dispatcher_str,
                signatures=signature_str,
                max_args=len(func_proto.op_proto.op_args),
                ut_body=ut_body))
            cpp_class_name_list.append(class_name)
        return func_call_body_list, cpp_class_name_list

    def _get_overload_func_call_body_list(self, overload_op_func_data):
        """
        Generates the list of call body strings for overloaded operation functions.

        Args:
            overload_op_func_data (dict): Dictionary of tensor function prototypes with overloaded definitions.

        Returns:
            func_call_body_list (list): Str list for generating C++ function call bodies.
            cpp_class_name_list (list): The list of overloaded c++ functional classes' names, in the same
                order as `overload_op_func_data`.
        """
        func_call_body_list, cpp_class_name_list = [], []
        for func_api_name, func_protos in overload_op_func_data.items():
            # _get_overload_func_call_str appends the generated class name to cpp_class_name_list.
            func_call_body_list.append(
                self._get_overload_func_call_str(func_api_name, func_protos, cpp_class_name_list))
        return func_call_body_list, cpp_class_name_list

    def _get_overload_func_call_str(self, func_api_name, func_protos, cpp_class_name_list):
        """
        Generates C++ call body string for overloaded tensor functions.

        Args:
            func_api_name (str): Name of the function API.
            func_protos (list): List of TensorFuncProto objects representing the function prototypes.
            cpp_class_name_list (list): Output list; the formatted C++ class name is appended to it.

        Returns:
            str: Generated call body string for the overloaded functions.
        """
        signatures_str = self._generate_func_signatures_str(func_protos)
        dispatch_cases = self._get_dispatch_cases(func_protos)
        ut_dispatch_cases = self._get_ut_dispatch_cases(func_protos)
        ut_overload_body = self.TENSOR_FUNC_UT_OVERLOAD_BODY.replace(ut_dispatch_cases=ut_dispatch_cases)

        # The parser buffer must hold the longest argument list among all overloads.
        max_size = max((len(proto.op_proto.op_args) for proto in func_protos), default=0)

        cpp_func_name = pyboost_utils.format_func_api_name(func_api_name)
        cpp_class_name_list.append(cpp_func_name)
        return self.PYBOOST_OVERLOAD_MINT_CLASS_DEF.replace(cpp_func_name=cpp_func_name,
                                                           func_name=func_api_name,
                                                           signatures=signatures_str,
                                                           dispatch_cases=dispatch_cases,
                                                           max_args=max_size,
                                                           ut_overload_body=ut_overload_body)

    def _generate_func_signatures_str(self, func_protos) -> str:
        """
        Generates function signatures as a string from the given prototypes.

        Args:
            func_protos (list): List of TensorFuncProto objects representing the function prototypes.

        Returns:
            str: The signatures of all prototypes, separated by ",\\n" (empty for no prototypes).
        """
        return ',\n'.join(
            self._generate_single_signature_str(proto.op_proto, proto.kw_only_args, proto.varargs)
            for proto in func_protos)

    def _generate_single_signature_str(self, op_proto: OpProto, kw_only_args, varargs) -> str:
        """Generates the (non-tensor-API) signature string of one operator prototype."""
        op_parser = OpTemplateParser(op_proto)
        return op_parser.generate_signature_str(kw_only_args, varargs, is_tensor_api=False)

    def _get_dispatch_cases(self, func_protos):
        """
        Generates C++ switch-case statements for dispatching tensor function calls.

        Case `i` handles the i-th prototype; the `default` case returns None.

        Args:
            func_protos (list): List of TensorFuncProto objects representing the function prototypes.

        Returns:
            str: Generated switch-case dispatch statements.
        """
        cases = [
            self.single_case_template.replace(case_id=idx,
                                              device_dispatcher=self._get_device_dispatchers_str(func_proto))
            for idx, func_proto in enumerate(func_protos)
        ]
        return ''.join(cases) + 'default:\n' + ' return py::none();'

    def _get_ut_dispatch_cases(self, func_protos):
        """
        Generates the UT variant of the switch-case statements for dispatching tensor function calls.

        Every case calls back into the prototype's Python method; the `default` case sets `res` to None.

        Args:
            func_protos (list): List of TensorFuncProto objects representing the function prototypes.

        Returns:
            str: Generated switch-case dispatch statements.
        """
        cases = [
            self.single_case_in_ut_template.replace(
                case_id=idx,
                device_dispatcher=self.callback_python_in_ut_template.replace(py_method=func_proto.py_method))
            for idx, func_proto in enumerate(func_protos)
        ]
        return ''.join(cases) + 'default:\n' + ' res = py::none();'

    def _get_device_dispatchers_str(self, func_proto):
        """
        Generates device-specific dispatch strings for the given function prototype.

        Args:
            func_proto (TensorFuncProto): Function prototype to generate dispatch strings for.

        Returns:
            str: Backend if/else chain covering Ascend, CPU and GPU.
        """
        ascend_dispatcher_str = self._get_single_device_dispatcher_str(func_proto, 'ascend')
        cpu_dispatcher_str = self._get_single_device_dispatcher_str(func_proto, 'cpu')
        gpu_dispatcher_str = self._get_single_device_dispatcher_str(func_proto, 'gpu')
        return self.device_dispatcher_template.replace(ascend_dispatcher=ascend_dispatcher_str,
                                                       cpu_dispatcher=cpu_dispatcher_str,
                                                       gpu_dispatcher=gpu_dispatcher_str)

    def _get_single_device_dispatcher_str(self, func_proto, device):
        """
        Generates the dispatch string for a specific device.

        Args:
            func_proto (TensorFuncProto): Function prototype to generate the dispatcher for.
            device (str): Device type ('ascend', 'cpu', 'gpu'); read as an attribute of `func_proto`.

        Returns:
            str: Generated device dispatcher string.

        Raises:
            TypeError: If the prototype's mode for `device` is neither 'pyboost' nor 'py_method'.
        """
        func_proto_device = getattr(func_proto, device)
        if func_proto_device == 'pyboost':
            op_proto = func_proto.op_proto
            arg_handler_processor_str = self._get_arg_handler_processor(func_proto.func_name, op_proto)
            convert_args_str = self._get_convert_args_str(op_proto)
            op_pyboost_func_name = OpTemplateParser(op_proto).get_pyboost_func_name()
            class_name = op_proto.op_class.name
            return self.pyboost_return_template.replace(arg_handler_processor=arg_handler_processor_str,
                                                        class_name=class_name,
                                                        prim_name=f"prim::kPrim{class_name}",
                                                        pyboost_base_func_name=op_pyboost_func_name,
                                                        convert_args=convert_args_str)
        if func_proto_device == 'py_method':
            return self.callback_python_template.replace(py_method=func_proto.py_method)

        # The accepted values are 'pyboost' and 'py_method'; report what was actually found.
        raise TypeError(
            f"Unsupported dispatch mode '{func_proto_device}' for device '{device}' of function "
            f"'{func_proto.func_name}'. Only 'pyboost' or 'py_method' is supported.")

    def _get_arg_handler_processor(self, func_name, op_proto):
        """Generates the (non-tensor-API) argument handler processing code for `op_proto`."""
        op_parser = OpTemplateParser(op_proto)
        return op_parser.get_arg_handler_processor(func_name, op_proto, is_tensor_api=False)

    def _get_convert_args_str(self, op_proto):
        """Generates the (non-tensor-API) converted argument list passed to the PyBoost op call."""
        op_parser = OpTemplateParser(op_proto)
        return op_parser.get_convert_args_str(op_proto, is_tensor_api=False)

    def _get_mint_func_reg_list(self, single_mint_func_data, overload_mint_func_data, cpp_class_names):
        """
        Generates the list of pybind definition strings for mint functions.

        Args:
            single_mint_func_data (dict): Dictionary of single mint function data.
            overload_mint_func_data (dict): Dictionary of overload mint function data.
            cpp_class_names (list): C++ class names, ordered as the single keys followed by the overload keys.

        Returns:
            list: list of strs for generating pybind definitions of mint functions' API.
        """
        # The order of single_mint_func_data/overload_mint_func_data matters: names are paired by position.
        mint_func_names = list(single_mint_func_data) + list(overload_mint_func_data)
        return [
            self.pybind_register_template.replace(mint_func_name=mint_func_name, cpp_func_name=cpp_func_name)
            for mint_func_name, cpp_func_name in zip(mint_func_names, cpp_class_names)
        ]
|
|
@@ -16,7 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
import os
|
|
18
18
|
import logging
|
|
19
|
-
from gen_utils import safe_load_yaml
|
|
19
|
+
from common.gen_utils import safe_load_yaml
|
|
20
|
+
from common.op_proto import OpProto
|
|
21
|
+
import common.gen_constants as K
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
def is_optional_param(op_arg):
|
|
@@ -113,6 +115,54 @@ def get_convert_type_str(dtype: str, optional):
|
|
|
113
115
|
raise TypeError(f"""Unsupported convert type {dtype} for args.""")
|
|
114
116
|
|
|
115
117
|
|
|
118
|
+
def get_input_args_type_str(dtype: str, optional):
    """
    Map an argument dtype to the C++ type string used for input args.

    Args:
        dtype (str): Argument dtype, e.g. 'int', 'tensor', 'tuple[int]'.
        optional (bool): When truthy, the mapped type is wrapped in ``std::optional<...>``.

    Returns:
        str: The C++ type string.

    Raises:
        TypeError: If `dtype` has no mapping. An optional 'bool' is rejected as well.
    """
    # add more type here
    value_tuple = 'ValueTuplePtr'
    base_types = {
        'int': 'Int64ImmPtr',
        'float': 'FP32ImmPtr',
        'bool': 'BoolImmPtr',
        'number': 'ScalarPtr',
        'tensor': 'ValuePtr',
        'str': 'StringImmPtr',
        'type': 'Int64ImmPtr',
        'tuple[int]': value_tuple,
        'tuple[float]': value_tuple,
        'tuple[bool]': value_tuple,
        'tuple[tensor]': value_tuple,
        'list[int]': value_tuple,
        'list[float]': value_tuple,
        'list[bool]': value_tuple,
        'list[tensor]': value_tuple,
    }

    if not optional:
        if dtype not in base_types:
            raise TypeError(f"""Unknown type {dtype} for args.""")
        return base_types[dtype]

    # Optional variants exist for every base type except 'bool'.
    # NOTE(review): get_input_dtype does accept optional 'bool'; confirm whether this omission is intended.
    if dtype == 'bool' or dtype not in base_types:
        raise TypeError(f"""Unknown optional type {dtype} for args.""")
    return f'std::optional<{base_types[dtype]}>'
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
116
166
|
def get_value_convert_type_str(dtype: str, optional):
|
|
117
167
|
"""
|
|
118
168
|
Convert type
|
|
@@ -163,11 +213,11 @@ def tuple_input_to_cpp_type(dtype: str):
|
|
|
163
213
|
'tuple[float]': 'float',
|
|
164
214
|
'tuple[bool]': 'bool',
|
|
165
215
|
'tuple[str]': 'string',
|
|
166
|
-
'tuple[tensor]': 'TensorPtr',
|
|
216
|
+
'tuple[tensor]': 'mindspore::tensor::TensorPtr',
|
|
167
217
|
'list[int]': 'int64_t',
|
|
168
218
|
'list[float]': 'float',
|
|
169
219
|
'list[bool]': 'bool',
|
|
170
|
-
'list[tensor]': 'TensorPtr',
|
|
220
|
+
'list[tensor]': 'mindspore::tensor::TensorPtr',
|
|
171
221
|
}
|
|
172
222
|
return types_map.get(dtype)
|
|
173
223
|
|
|
@@ -187,14 +237,14 @@ def get_input_dtype(dtype: str, optional):
|
|
|
187
237
|
Convert type
|
|
188
238
|
"""
|
|
189
239
|
# add more type here
|
|
190
|
-
value_tuple = 'ValueTuplePtr'
|
|
240
|
+
value_tuple = 'mindspore::ValueTuplePtr'
|
|
191
241
|
type_convert = {
|
|
192
|
-
'int': 'Int64ImmPtr',
|
|
193
|
-
'float': 'FP32ImmPtr',
|
|
194
|
-
'bool': 'BoolImmPtr',
|
|
195
|
-
'number': 'ScalarPtr',
|
|
196
|
-
'str': 'StringImmPtr',
|
|
197
|
-
'tensor': 'BaseTensorPtr',
|
|
242
|
+
'int': 'mindspore::Int64ImmPtr',
|
|
243
|
+
'float': 'mindspore::FP32ImmPtr',
|
|
244
|
+
'bool': 'mindspore::BoolImmPtr',
|
|
245
|
+
'number': 'mindspore::ScalarPtr',
|
|
246
|
+
'str': 'mindspore::StringImmPtr',
|
|
247
|
+
'tensor': 'mindspore::tensor::BaseTensorPtr',
|
|
198
248
|
'tuple[int]': value_tuple,
|
|
199
249
|
'tuple[float]': value_tuple,
|
|
200
250
|
'tuple[bool]': value_tuple,
|
|
@@ -204,14 +254,14 @@ def get_input_dtype(dtype: str, optional):
|
|
|
204
254
|
'list[bool]': value_tuple,
|
|
205
255
|
'list[tensor]': value_tuple,
|
|
206
256
|
}
|
|
207
|
-
value_tuple_optional = 'std::optional<ValueTuplePtr>'
|
|
257
|
+
value_tuple_optional = 'std::optional<mindspore::ValueTuplePtr>'
|
|
208
258
|
optional_type_convert = {
|
|
209
|
-
'int': 'std::optional<Int64ImmPtr>',
|
|
210
|
-
'float': 'std::optional<FP32ImmPtr>',
|
|
211
|
-
'bool': 'std::optional<BoolImmPtr>',
|
|
212
|
-
'number': 'std::optional<ScalarPtr>',
|
|
213
|
-
'str': 'std::optional<StringImmPtr>',
|
|
214
|
-
'tensor': 'std::optional<BaseTensorPtr>',
|
|
259
|
+
'int': 'std::optional<mindspore::Int64ImmPtr>',
|
|
260
|
+
'float': 'std::optional<mindspore::FP32ImmPtr>',
|
|
261
|
+
'bool': 'std::optional<mindspore::BoolImmPtr>',
|
|
262
|
+
'number': 'std::optional<mindspore::ScalarPtr>',
|
|
263
|
+
'str': 'std::optional<mindspore::StringImmPtr>',
|
|
264
|
+
'tensor': 'std::optional<mindspore::tensor::BaseTensorPtr>',
|
|
215
265
|
'tuple[int]': value_tuple_optional,
|
|
216
266
|
'tuple[float]': value_tuple_optional,
|
|
217
267
|
'tuple[bool]': value_tuple_optional,
|
|
@@ -239,9 +289,9 @@ def get_return_type(dtype: str):
|
|
|
239
289
|
"""
|
|
240
290
|
# add more type here
|
|
241
291
|
type_convert = {
|
|
242
|
-
'tuple[tensor]': 'std::vector<tensor::TensorPtr>',
|
|
243
|
-
'list[tensor]': 'std::vector<tensor::TensorPtr>',
|
|
244
|
-
'tensor': 'tensor::TensorPtr',
|
|
292
|
+
'tuple[tensor]': 'std::vector<mindspore::tensor::TensorPtr>',
|
|
293
|
+
'list[tensor]': 'std::vector<mindspore::tensor::TensorPtr>',
|
|
294
|
+
'tensor': 'mindspore::tensor::TensorPtr',
|
|
245
295
|
}
|
|
246
296
|
if dtype in type_convert:
|
|
247
297
|
return type_convert[dtype]
|
|
@@ -266,11 +316,10 @@ def get_op_name(operator_name, class_def):
|
|
|
266
316
|
"""
|
|
267
317
|
Get op name for python class Primitive or c++ OpDef name.
|
|
268
318
|
"""
|
|
319
|
+
if class_def:
|
|
320
|
+
return class_def
|
|
321
|
+
|
|
269
322
|
class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
|
|
270
|
-
if class_def is not None:
|
|
271
|
-
item = class_def.get("name")
|
|
272
|
-
if item is not None:
|
|
273
|
-
class_name = item
|
|
274
323
|
return class_name
|
|
275
324
|
|
|
276
325
|
|
|
@@ -278,10 +327,6 @@ def get_pyboost_name(operator_name):
|
|
|
278
327
|
return 'pyboost_' + operator_name
|
|
279
328
|
|
|
280
329
|
|
|
281
|
-
def convert_python_func_name_to_c(func_name: str) -> str:
|
|
282
|
-
return ''.join(word.capitalize() for word in func_name.split('_'))
|
|
283
|
-
|
|
284
|
-
|
|
285
330
|
def get_const_number_convert(arg_name, op_arg):
|
|
286
331
|
cpp_type = number_input_to_cpp_type(op_arg.arg_dtype)
|
|
287
332
|
if op_arg.is_type_id:
|
|
@@ -297,8 +342,8 @@ def get_tuple_input_convert(arg_name, arg_type):
|
|
|
297
342
|
:return:
|
|
298
343
|
"""
|
|
299
344
|
cpp_type = tuple_input_to_cpp_type(arg_type)
|
|
300
|
-
if cpp_type == "TensorPtr":
|
|
301
|
-
cpp_type = "BaseTensorPtr"
|
|
345
|
+
if cpp_type == "mindspore::tensor::TensorPtr":
|
|
346
|
+
cpp_type = "mindspore::tensor::BaseTensorPtr"
|
|
302
347
|
return f"std::vector<{cpp_type}> {arg_name}_vector = ConvertValueTupleToVector<{cpp_type}>({arg_name});\n"
|
|
303
348
|
|
|
304
349
|
|
|
@@ -311,12 +356,52 @@ def is_pyboost_enable(operator_data):
|
|
|
311
356
|
return False
|
|
312
357
|
|
|
313
358
|
|
|
359
|
+
def format_func_api_name(func_api_name):
    """
    Convert a snake_case API name into the PascalCase form used for C++ class names.

    Three shapes are handled:
      * dunder names (``__add__``) lose both underscore pairs and get a ``Magic`` suffix -> ``AddMagic``;
      * names with a trailing underscore (``index_add_``) keep exactly one trailing ``_`` -> ``IndexAdd_``;
      * anything else is simply PascalCased (``max_pool2d`` -> ``MaxPool2d``).

    Args:
        func_api_name (str): The input snake_case string.

    Returns:
        str: The converted PascalCase string.
    """
    is_dunder = func_api_name.startswith('__') and func_api_name.endswith('__')

    if is_dunder:
        core, suffix = func_api_name[2:-2], 'Magic'
    elif func_api_name.endswith('_'):
        core, suffix = func_api_name[:-1], '_'
    else:
        core, suffix = func_api_name, ''

    # str.capitalize upper-cases the first character and lower-cases the rest of each word.
    pascal_name = ''.join(word.capitalize() for word in core.split('_'))
    return pascal_name + suffix
|
|
397
|
+
|
|
398
|
+
|
|
314
399
|
def convert_types(inputs):
|
|
315
400
|
'''convert type to acl type'''
|
|
316
401
|
inputs_dtypes = {}
|
|
317
402
|
flag = False
|
|
318
403
|
for i in inputs:
|
|
319
|
-
inputs_dtypes[i] =
|
|
404
|
+
inputs_dtypes[i] = i.arg_dtype
|
|
320
405
|
if inputs_dtypes[i] != 'tensor':
|
|
321
406
|
flag = True
|
|
322
407
|
if 'tuple' in inputs_dtypes[i]:
|
|
@@ -338,22 +423,80 @@ def convert_types(inputs):
|
|
|
338
423
|
return inputs_dtypes, flag
|
|
339
424
|
|
|
340
425
|
|
|
341
|
-
def get_dtypes(op_proto: OpProto):
    """
    Get the converted dtypes of an operator's inputs and outputs.

    Args:
        op_proto (OpProto): Operator prototype; its ``op_args`` and ``op_returns`` are inspected.

    Returns:
        tuple: ``(inputs_dtypes, outputs_dtypes, none_tensor_exist)``. The first two are the dicts
        returned by ``convert_types``. ``none_tensor_exist`` is ``flag_in or flag_out`` of the two
        ``convert_types`` calls; that flag is set at least when an arg dtype other than 'tensor' is seen.
    """
    args, returns = op_proto.op_args, op_proto.op_returns
    inputs_dtypes, flag_in = convert_types(args)
    outputs_dtypes, flag_out = convert_types(returns)
    return inputs_dtypes, outputs_dtypes, (flag_in or flag_out)
|
|
349
434
|
|
|
350
435
|
|
|
436
|
+
def merge_strings_by_chunk_size(string_list, chunk_size=50):
    """
    Concatenate consecutive strings of a list in groups of ``chunk_size`` elements.

    ``chunk_size`` counts list *elements*, not characters. Every ``chunk_size`` consecutive
    strings are joined into one string, and the last group may contain fewer elements.

    Args:
        string_list (list of str): Strings to merge, in order.
        chunk_size (int, optional): Number of list elements joined into each merged string. Defaults to 50.

    Returns:
        list of str: The merged strings, ``ceil(len(string_list) / chunk_size)`` of them (empty for empty input).

    Raises:
        ValueError: If `chunk_size` is less than or equal to 0.

    Example:
        >>> strings = ["Hello", "world", "this", "is", "a", "test"]
        >>> merge_strings_by_chunk_size(strings, chunk_size=2)
        ['Helloworld', 'thisis', 'atest']
    """
    # A zero step would make range() fail and a negative one would silently yield nothing;
    # reject both explicitly, consistent with chunk_list.
    if chunk_size <= 0:
        raise ValueError("The chunk size 'chunk_size' must be a positive integer.")

    return [
        "".join(string_list[i:i + chunk_size])  # Merge the current group of strings
        for i in range(0, len(string_list), chunk_size)
    ]
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def chunk_list(lst, n):
    """
    Split a sequence into consecutive slices of at most `n` elements.

    Each piece is the slice ``lst[start:start + n]``. When ``len(lst)`` is not a
    multiple of `n`, the final piece is shorter.

    Args:
        lst (list): The original list to be divided.
        n (int): The number of elements per sublist.

    Returns:
        list: The slices in order; empty when `lst` is empty.

    Raises:
        ValueError: If `n` is less than or equal to 0.

    Example:
        >>> chunk_list([1, 2, 3, 4, 5, 6, 7, 8, 9], 3)
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    """
    if n <= 0:
        raise ValueError("The chunk size 'n' must be a positive integer.")

    pieces = []
    for start in range(0, len(lst), n):
        pieces.append(lst[start:start + n])
    return pieces
|
|
492
|
+
|
|
493
|
+
|
|
351
494
|
class AclnnUtils:
|
|
352
495
|
"""
|
|
353
496
|
aclnn utils
|
|
354
497
|
"""
|
|
355
|
-
|
|
356
|
-
|
|
498
|
+
aclnn_map = safe_load_yaml(os.path.join(
|
|
499
|
+
K.WORK_DIR, K.PY_OPS_GEN_PATH, "pyboost/aclnn_config.yaml"))
|
|
357
500
|
|
|
358
501
|
@staticmethod
|
|
359
502
|
def get_aclnn_interface(class_name):
|
|
File without changes
|