mindspore 2.4.10-cp311-cp311-win_amd64.whl → 2.6.0rc1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Generates C++ functional map header files for graph mode.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
import common.gen_constants as K
|
|
22
|
+
import common.template as template
|
|
23
|
+
from common.gen_utils import save_file, OrderedSet
|
|
24
|
+
from common.base_generator import BaseGenerator
|
|
25
|
+
from pyboost import pyboost_utils
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FunctionalMapCppGenerator(BaseGenerator):
    """
    Generates the C++ functional map files (functional_map.cc and functional_map.h) for graph mode.

    The generated tables map Tensor method names and mint function names (plus their aliases) to the
    primitives that implement them, together with keyword-only argument tables, variable-argument index
    tables and signature strings.
    """

    def __init__(self):
        """
        Initializes the generator with templates for the functional map.
        """
        self.FUNCTIONAL_MAP_CC_TEMPLATE = template.FUNCTIONAL_MAP_CC_TEMPLATE
        self.FUNCTIONAL_MAP_H_TEMPLATE = template.FUNCTIONAL_MAP_H_TEMPLATE
        self.class_to_method_template = template.Template("{\"${class_name}\", \"${method_name}\"}")
        self.functional_map_template = template.Template("{\"${func_api_name}\", {${class_to_method_str}}},")
        self.k_prim_op_template = template.Template("prim::kPrim${camel_op_name}")
        self.tensor_method_kwonlyargs_map_template = template.Template(
            "{\"${camel_op_name}\", {${kw_only_args_list}}},")
        self.tensor_method_varargs_map_template = \
            template.Template("{\"${op_name}\", ${vararg_index}},")
        self.deprecated_method_decl_template = template.Template(
            "auto ${dep_op_name} = std::make_shared<prim::DeprecatedTensorMethod>(\"${dep_op_name}\", \"${op_name}\");")
        self.functional_method_map_template = template.Template("{\"${op_name}\", {${sort_func_method_list_str}}},")

        # Arg handler name -> raw types a caller may pass for an argument converted by that handler.
        self.arg_handler_map = {"to_2d_paddings": ["tuple[int]", "list[int]", "int"],
                                "dtype_to_type_id": ["int", "type"],
                                "to_kernel_size": ["tuple[int]", "list[int]", "int"],
                                "to_strides": ["tuple[int]", "list[int]", "int"],
                                "str_to_enum": ["str"],
                                "to_pair": ["tuple[int]", "list[int]", "int", "float"],
                                "to_dilations": ["tuple[int]", "list[int]", "int"],
                                "to_output_padding": ["tuple[int]", "list[int]", "int"],
                                "to_rates": ["tuple[int]", "list[int]", "int"]}

        # Raw type name -> type name displayed in the generated signature strings.
        self.prompt_type_map = {"any": "any",
                                "int": "int",
                                "float": "float",
                                "str": "str",
                                "bool": "bool",
                                "number": "number",
                                "tensor": "Tensor",
                                "type": "mstype",
                                "None": "None"}

    def generate(self, work_path, tensor_method_protos_data, mint_func_protos_data, alias_func_mapping):
        """
        Generates functional_map.cc and functional_map.h and saves them under the functional overload path.

        Args:
            work_path (str): Root directory; files are written to ``work_path/K.FUNCTIONAL_OVERLOAD_GEN_PATH``.
            tensor_method_protos_data (dict): A dictionary mapping Tensor method API names to their prototype data.
            mint_func_protos_data (dict): A dictionary mapping mint API names to their prototype data.
            alias_func_mapping (dict): A dictionary mapping function name to its alias function names.

        Returns:
            None
        """
        dep_method_decl_list = self._get_dep_method_decl_list(tensor_method_protos_data)
        tensor_method_overload_list, op_inc_1 = self._get_functional_method_map(tensor_method_protos_data,
                                                                               alias_func_mapping)
        mint_overload_list, op_inc_2 = self._get_functional_mint_map(mint_func_protos_data, alias_func_mapping)
        tensor_method_kw_only_args_list = self._get_tensor_method_kwonlyargs_map(tensor_method_protos_data)
        mint_kw_only_args_list = self._get_mint_kwonlyargs_map(mint_func_protos_data, alias_func_mapping)
        tensor_varargs_map_list = self._get_tensor_varargs_map_list(tensor_method_protos_data)
        mint_varargs_map_list = self._get_mint_varargs_map_list(mint_func_protos_data, alias_func_mapping)
        funcs_sig_map_list = self._get_func_sigs_list(tensor_method_protos_data, alias_func_mapping,
                                                      is_tensor_method=True)
        funcs_mint_sigs_map = self._get_func_sigs_list(mint_func_protos_data, alias_func_mapping,
                                                       is_tensor_method=False)

        # The include template is parameterized only by the lower-cased first letter of the op name,
        # so the set yields one include line per distinct letter.
        ops_inc_head_set = {template.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_inc[0].lower())
                            for op_inc in op_inc_1 | op_inc_2}

        functional_map_cc_code = self.FUNCTIONAL_MAP_CC_TEMPLATE.replace(
            ops_inc=sorted(ops_inc_head_set),
            deprecated_method_decl=dep_method_decl_list,
            tensor_method_map=tensor_method_overload_list,
            mint_map=mint_overload_list,
            tensor_method_kwonlyargs_map=tensor_method_kw_only_args_list,
            mint_kwonlyargs_map=mint_kw_only_args_list,
            tensor_varargs_map=tensor_varargs_map_list,
            mint_varargs_map=mint_varargs_map_list,
            tensor_method_sigs_map=funcs_sig_map_list,
            mint_sigs_map=funcs_mint_sigs_map)
        save_path = os.path.join(work_path, K.FUNCTIONAL_OVERLOAD_GEN_PATH)
        save_file(save_path, "functional_map.cc", functional_map_cc_code)
        save_file(save_path, "functional_map.h", self.FUNCTIONAL_MAP_H_TEMPLATE.replace())

    def _get_func_sigs_list(self, tensor_method_protos_data, alias_func_mapping, is_tensor_method):
        """
        Generates one signature-map entry per API name and per alias of that name.

        Args:
            tensor_method_protos_data (dict): A dictionary mapping function API names to their prototype lists.
                Used for both Tensor methods and mint functions, depending on ``is_tensor_method``.
            alias_func_mapping (dict): A dictionary mapping function name to its alias function names.
            is_tensor_method (bool): Whether the prototype data is a tensor method or a mint function.

        Returns:
            list: Signature-map entry strings; an alias entry reuses the prototypes of its original API.
        """
        funcs_list = []
        for func_api_name, func_protos in tensor_method_protos_data.items():
            funcs_list.append(self._generate_func_signatures_str(func_api_name, func_protos, is_tensor_method))
            if func_api_name in alias_func_mapping:
                for alias_api_name in alias_func_mapping[func_api_name]:
                    funcs_list.append(
                        self._generate_func_signatures_str(alias_api_name, func_protos, is_tensor_method))
        return funcs_list

    def _generate_func_signatures_str(self, func_api_name, func_protos, is_tensor_method) -> str:
        """
        Generates the signature-map entry for one API name from its prototypes.

        Args:
            func_api_name (str): The name of the API to generate signatures for.
            func_protos (list): List of TensorFuncProto objects representing the function prototypes.
            is_tensor_method (bool): Whether signatures are rendered as ``Tensor.<name>(...)``.

        Returns:
            str: ``{"<name>",\\n {<sig1>,\\n<sig2>...}\\n},`` with duplicate signatures removed
            (first occurrence order kept).
        """
        sig_set = OrderedSet()
        for tensor_proto in func_protos:
            sig_set.add(self._generate_single_signature_str(func_api_name, tensor_proto, is_tensor_method))
        return '{' + f'\"{func_api_name}\",\n ' + '{' + ',\n'.join(list(sig_set)) + '}\n},'

    def _is_input_arg(self, arg_name, op_name):
        """
        Returns True if ``arg_name`` is the input argument of ``op_name``.

        Ops listed in ``K.INPUT_NAME_MAP`` have an explicit input name; all other ops fall back to
        membership in ``K.INPUT_ARGS_NAME``.
        """
        if op_name in K.INPUT_NAME_MAP:
            return arg_name == K.INPUT_NAME_MAP[op_name]
        return arg_name in K.INPUT_ARGS_NAME

    def _generate_single_signature_str(self, func_api_name, tensor_proto, is_tensor_method) -> str:
        """
        Generates a single function signature string for the given operation prototype.

        Args:
            func_api_name (str): The name of the API to generate signatures for.
            tensor_proto (TensorFuncProto): The prototype whose ``op_proto.op_args`` are rendered.
            is_tensor_method (bool): If True, render as ``Tensor.<name>(...)`` and skip the input argument
                (presumably the Tensor the method is called on).

        Returns:
            str: A quoted signature such as ``"Tensor.name(a=<int,Tensor>, *, b=<bool>)"``.
        """
        op_proto = tensor_proto.op_proto
        op_name = op_proto.op_class.name
        head = f'"Tensor.{func_api_name}(' if is_tensor_method else f'"{func_api_name}('
        first_arg = True
        is_kw_args_init = False
        rendered_args = []
        for arg in op_proto.op_args:
            arg_name = arg.arg_name
            if is_tensor_method and self._is_input_arg(arg_name, op_name):
                continue
            arg_valid_types = self._handle_arg_valid_types(arg, arg_name, [], func_api_name)
            single_arg = f'{arg_name}=<' + ','.join(arg_valid_types) + '>'
            prefix, is_kw_args_init, first_arg = self._build_prefix(arg_name, first_arg, is_kw_args_init,
                                                                    tensor_proto)
            rendered_args.append(prefix + single_arg)
        return head + ''.join(rendered_args) + ')"'

    def _build_prefix(self, arg_name, first_arg, is_kw_args_init, tensor_proto):
        """
        Build and return the prefix for the current argument, handling insertion of kw-only args if needed.

        Args:
            arg_name (str): Name of the current argument.
            first_arg (bool): Indicates whether this is the first argument.
            is_kw_args_init (bool): Indicates whether kw-only args prefix has been inserted.
            tensor_proto: Contains information about kw_only_args.

        Returns:
            tuple:
                prefix (str): ``""`` for the first argument, otherwise ``", "``; followed by ``"*, "`` when
                    this is the first keyword-only argument.
                is_kw_args_init (bool): Updated kw-only args insertion status.
                first_arg (bool): Always False once an argument has been emitted.
        """
        prefix = "" if first_arg else ", "
        if tensor_proto.kw_only_args and not is_kw_args_init and arg_name == tensor_proto.kw_only_args[0]:
            prefix += "*, "
            is_kw_args_init = True
        return prefix, is_kw_args_init, False

    def _handle_arg_valid_types(self, arg, arg_name, arg_valid_types, func_api_name):
        """
        Collect and return valid argument types based on the arg handler, defaults, and casts.

        Args:
            arg: Argument object containing handler, dtype, and cast info.
            arg_name (str): Name of the current argument.
            arg_valid_types (list): Existing valid types to be extended (mutated in place).
            func_api_name (str): Name of the current API function.

        Returns:
            list: Generalized (display) type names, sorted in descending order.

        Raises:
            ValueError: If ``arg.arg_handler`` is not registered in ``self.arg_handler_map`` or a type
                cannot be generalized.
        """
        arg_handler = arg.arg_handler
        if arg_handler != '':
            if arg_handler not in self.arg_handler_map:
                raise ValueError("Generate failed. Check if {} is registered in "
                                 "FunctionalMapCppGenerator.arg_handler_map.".format(arg_handler))
            # The handler fully determines the accepted raw types.
            arg_valid_types.extend(self.arg_handler_map[arg_handler])
        else:
            arg_valid_types.append(arg.arg_dtype)
            arg_valid_types.extend(arg.type_cast)

        if arg.as_init_arg and str(arg.default) == 'None':
            arg_valid_types.append('None')

        generalized_types = self._parse_arg_type_list(func_api_name, arg_name, arg_valid_types)
        return sorted(generalized_types, reverse=True)

    def _parse_arg_type_list(self, func_api_name, arg_name, arg_valid_types):
        """
        Parses a list of argument types and maps them to generalized types.

        Args:
            func_api_name (str): The name of the function API for which the argument types are being parsed.
            arg_name (str): The name of the argument whose valid types are being generalized.
            arg_valid_types (list): A list of valid argument types that need to be generalized.

        Returns:
            set: Generalized type names. Known names go through ``self.prompt_type_map``; otherwise any type
            containing "list" becomes 'List' (checked first) and any containing "tuple" becomes 'Tuple'.

        Raises:
            ValueError: If an unrecognized or invalid type is encountered in the argument types list.
        """
        generalized_type_list = set()
        for arg_type in arg_valid_types:
            if arg_type in self.prompt_type_map:
                generalized_type_list.add(self.prompt_type_map[arg_type])
            elif "list" in arg_type:
                generalized_type_list.add('List')
            elif "tuple" in arg_type:
                generalized_type_list.add('Tuple')
            else:
                raise ValueError(f"Invalid type {arg_type} in api: {func_api_name} {arg_name}.")
        return generalized_type_list

    def _get_dep_method_decl_list(self, func_protos_data):
        """
        Extracts and generates declarations for deprecated methods from the provided function prototypes.

        Only prototypes whose ``op_proto.op_name`` starts with "deprecated" produce a declaration; the
        snake_case op name is converted to CamelCase for the declared variable name.

        Args:
            func_protos_data (dict): A dictionary where keys are function API names and values are lists
                of function prototypes. Each prototype contains an operation name.

        Returns:
            list: A list of strings, each representing a declaration for a deprecated method.
        """
        deprecated_method_decl_list = []
        for func_api_name, func_protos in func_protos_data.items():
            for func_proto in func_protos:
                op_name = func_proto.op_proto.op_name
                if not op_name.startswith("deprecated"):
                    continue

                deprecated_method_name = ''.join(word.capitalize() for word in op_name.split('_'))
                # split('_') loses a trailing underscore, so restore it.
                if op_name.endswith('_'):
                    deprecated_method_name += '_'
                deprecated_method_decl_list.append(
                    self.deprecated_method_decl_template.replace(dep_op_name=deprecated_method_name,
                                                                 op_name=func_api_name))

        return deprecated_method_decl_list

    def _get_functional_method_map(self, tensor_method_protos_data, alias_func_mapping):
        """
        Generates the Tensor method map entries from the provided function prototypes and alias mappings.

        Args:
            tensor_method_protos_data (dict): A dictionary where keys are function API names and values are lists
                of function prototypes.
            alias_func_mapping (dict): A dictionary mapping function API names to their aliases.

        Returns:
            tuple: (list of map entry strings, set of non-deprecated camel-case op names whose op_def
            headers must be included).
        """

        op_inc_set = set()

        def get_sort_func_method_list(func_protos):
            """
            Retrieves a sorted list of operator primitives, prioritizing deprecated operators.

            Deprecated ops are emitted as bare identifiers; all others as ``prim::kPrim<Name>``.
            The sort is stable, so relative order within each group is preserved.
            """
            func_method_list = []
            for func_proto in func_protos:
                k_op_name = pyboost_utils.get_op_name(func_proto.op_proto.op_name, func_proto.op_proto.op_class.name)
                if k_op_name.startswith("Deprecated"):
                    func_method_list.append(k_op_name)
                else:
                    func_method_list.append(self.k_prim_op_template.replace(camel_op_name=k_op_name))
                    op_inc_set.add(k_op_name)

            func_method_list.sort(key=lambda x: x.startswith("Deprecated"), reverse=True)
            return func_method_list

        functional_method_map_list = []
        for func_api_name, func_protos in tensor_method_protos_data.items():
            sort_func_method_list = get_sort_func_method_list(func_protos)
            functional_method_map_list.append(
                self.functional_method_map_template.replace(op_name=func_api_name,
                                                            sort_func_method_list_str=sort_func_method_list))

            # Aliases map to exactly the same primitive list as the original name.
            if func_api_name in alias_func_mapping:
                for alias in alias_func_mapping[func_api_name]:
                    functional_method_map_list.append(
                        self.functional_method_map_template.replace(op_name=alias,
                                                                    sort_func_method_list_str=sort_func_method_list))

        return functional_method_map_list, op_inc_set

    def _get_functional_mint_map(self, mint_func_protos_data, alias_func_mapping):
        """
        Generates the mint function map entries and collects the primitive names they reference.

        Args:
            mint_func_protos_data (dict): A dictionary mapping mint API names to their prototype data.
            alias_func_mapping (dict): A dictionary mapping function API names to their aliases.

        Returns:
            tuple: (list of map entry strings, set of camel-case op names whose op_def headers must be included).
        """

        op_inc_set = set()

        def get_mint_func_list(func_protos):
            """
            Returns the ``prim::kPrim<Name>`` references for the given prototypes, in prototype order.
            """
            func_method_list = []
            for func_proto in func_protos:
                k_op_name = pyboost_utils.get_op_name(func_proto.op_proto.op_name, func_proto.op_proto.op_class.name)
                func_method_list.append(self.k_prim_op_template.replace(camel_op_name=k_op_name))
                op_inc_set.add(k_op_name)

            return func_method_list

        mint_func_decl_list = []
        for func_api_name, func_protos in mint_func_protos_data.items():
            mint_func_list = get_mint_func_list(func_protos)
            mint_func_decl_list.append(
                self.functional_method_map_template.replace(op_name=func_api_name,
                                                            sort_func_method_list_str=mint_func_list))
            if func_api_name in alias_func_mapping:
                for alias in alias_func_mapping[func_api_name]:
                    mint_func_decl_list.append(
                        self.functional_method_map_template.replace(op_name=alias,
                                                                    sort_func_method_list_str=mint_func_list))
        return mint_func_decl_list, op_inc_set

    def _get_and_append_single_op_kw_only_args_list(self, func_protos, single_op_kw_only_args_list):
        """
        Extracts keyword-only arguments from a list of function prototypes and appends them to a list.

        Args:
            func_protos (list): A list of function prototypes.
            single_op_kw_only_args_list (list): The list to append the keyword-only arguments to (mutated).

        Returns:
            None
        """
        for func_proto in func_protos:
            kw_only_args = func_proto.kw_only_args
            if not kw_only_args:
                continue
            camel_op_name = pyboost_utils.get_op_name(func_proto.op_proto.op_name, func_proto.op_proto.op_class.name)
            kw_only_args_list = ", ".join(f"\"{kw_arg}\"" for kw_arg in kw_only_args)
            single_op_kw_only_args_list.append(
                self.tensor_method_kwonlyargs_map_template.replace(camel_op_name=camel_op_name,
                                                                   kw_only_args_list=kw_only_args_list)
            )

    def _get_tensor_method_kwonlyargs_map(self, tensor_method_protos_data):
        """
        Generates a list of keyword-only arguments for tensor methods.

        Args:
            tensor_method_protos_data (dict): A dictionary of tensor method prototype data.

        Returns:
            list: A list of formatted strings representing the keyword-only arguments.
        """
        tensor_method_kw_only_args_list = []
        for func_protos in tensor_method_protos_data.values():
            self._get_and_append_single_op_kw_only_args_list(func_protos,
                                                             tensor_method_kw_only_args_list)
        return tensor_method_kw_only_args_list

    def _get_mint_kwonlyargs_map(self, mint_func_protos_data, alias_func_mapping):
        """
        Generates a list of keyword-only arguments for mint functions.

        Args:
            mint_func_protos_data (dict): A dictionary of mint function prototype data.
            alias_func_mapping (dict): Currently unused by this method (entries are keyed by op name,
                not by API name).

        Returns:
            list: A list of formatted strings representing the keyword-only arguments.
        """
        mint_kw_only_args_list = []
        for func_protos in mint_func_protos_data.values():
            self._get_and_append_single_op_kw_only_args_list(func_protos,
                                                             mint_kw_only_args_list)
        return mint_kw_only_args_list

    def _get_and_append_single_op_varargs_list(self, func_protos, single_op_varargs_list):
        """
        Extracts variable arguments from a list of function prototypes and appends them to a list.

        Args:
            func_protos (list): A list of function prototypes.
            single_op_varargs_list (list): The list to append the variable arguments to (mutated).

        Returns:
            None

        Raises:
            ValueError: If a prototype declares more than one variable argument, or its variable argument
                does not match exactly one of the op's arguments.
        """
        for func_proto in func_protos:
            varargs = func_proto.varargs
            if not varargs:
                continue
            args = func_proto.op_proto.op_args
            op_name = func_proto.op_proto.op_class.name
            if len(varargs) != 1:
                # Report the number of declared varargs (the index list is not computed yet).
                raise ValueError(
                    f'There must be only one variable argument. But got {len(varargs)} in {op_name}')
            vararg_index = [i for i, arg in enumerate(args) if arg.arg_name == varargs[0]]
            if len(vararg_index) != 1:
                raise ValueError(
                    f'The variable arguments list of {op_name} is wrong, please check.')
            single_op_varargs_list.append(
                self.tensor_method_varargs_map_template.replace(op_name=op_name,
                                                                vararg_index=vararg_index[0])
            )

    def _get_tensor_varargs_map_list(self, tensor_method_protos_data):
        """
        Generates a list of variable arguments for tensor methods.

        Args:
            tensor_method_protos_data (dict): A dictionary of tensor method prototype data.

        Returns:
            list: A list of formatted strings representing the Variable arguments.
        """
        tensor_method_varargs_list = []
        for func_protos in tensor_method_protos_data.values():
            self._get_and_append_single_op_varargs_list(func_protos,
                                                        tensor_method_varargs_list)
        return tensor_method_varargs_list

    def _get_mint_varargs_map_list(self, mint_func_protos_data, alias_func_mapping):
        """
        Generates a list of variable arguments for mint functions.

        Args:
            mint_func_protos_data (dict): A dictionary of mint function prototype data.
            alias_func_mapping (dict): A dictionary mapping original function names to alias function names.

        Returns:
            list: A list of formatted strings representing the variable arguments.
        """
        mint_varargs_list = []
        for func_api_name, func_protos in mint_func_protos_data.items():
            self._get_and_append_single_op_varargs_list(func_protos,
                                                        mint_varargs_list)

            # NOTE(review): entries are keyed by op name, not API name, so this appends duplicates of the
            # entries just added (once per aliased API, regardless of alias count). Kept to preserve the
            # generated output; confirm whether this is intended.
            if mint_varargs_list and func_api_name in alias_func_mapping:
                self._get_and_append_single_op_varargs_list(func_protos,
                                                            mint_varargs_list)
        return mint_varargs_list
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Module for generating the Python functional overload file for mint APIs.

This module defines the `FunctionalOverloadPyGenerator` class, which produces `functional_overload.py`:
one Python wrapper function per mint API (and per alias) that forwards to its C++ functional instance.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
import common.gen_constants as K
|
|
25
|
+
import common.gen_utils as gen_utils
|
|
26
|
+
import common.template as template
|
|
27
|
+
from common.template import Template
|
|
28
|
+
|
|
29
|
+
from common.base_generator import BaseGenerator
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class FunctionalOverloadPyGenerator(BaseGenerator):
    """
    Generates the Python ``functional_overload.py`` module for mint functional APIs.

    For every mint API (and each of its aliases) a wrapper function is emitted that forwards
    ``*args, **kwargs`` to ``_<api>_instance`` imported from ``mindspore._c_expression``, with the
    docstring taken from the mint function doc YAML files.
    """

    def __init__(self):
        """
        Initializes the output template, the doc YAML directory and the per-API code templates.
        """
        self.FUNCTIONAL_OVERLOAD_PY_TEMPLATE = template.FUNCTIONAL_OVERLOAD_PY_TEMPLATE

        self.mint_func_doc_yaml_dir_path = os.path.join(K.WORK_DIR, K.MS_MINT_FUNC_DOC_YAML_PATH)
        self.import_mint_template = Template("from mindspore._c_expression import _${cpp_func_name}_instance\n")
        self.mint_def_template = Template(
            'def ${mint_func_name}(*args, **kwargs):\n'
            ' r"""\n${docstr}\n """\n'
            ' return _${cpp_func_name}_instance(*args, **kwargs)\n\n\n'
        )

    def generate(self, work_path, mint_func_protos_data, alias_api_mapping):
        """
        Generates ``functional_overload.py`` for the mint APIs and saves it under ``work_path``.

        Args:
            work_path (str): Root directory; the file is written to ``work_path/K.MS_MINT_FUNC_OVERLOAD_PATH``.
            mint_func_protos_data (dict): A dictionary mapping mint API names to their prototype data
                (only the names are used here).
            alias_api_mapping (dict): A dictionary mapping a mint API name to its alias names.

        Raises:
            KeyError: If ``validate_func_docs`` rejects the doc/API names, or an API or alias has no
                ``description`` in the loaded doc data.
        """
        function_doc_data = gen_utils.safe_load_yaml_from_dir(self.mint_func_doc_yaml_dir_path)
        validate_func_docs(mint_func_protos_data, function_doc_data, alias_api_mapping)
        # mint_init_list is never populated here; it is passed empty to fill the template slot.
        import_mint_list, mint_init_list, mint_def_list, add_to_all_list = [], [], [], []
        for mint_api_name in mint_func_protos_data:
            import_mint_list.append(self.import_mint_template.replace(cpp_func_name=mint_api_name))
            # The API itself first, then its aliases; aliases call the same C++ instance but keep their own docs.
            api_names = [mint_api_name]
            if mint_api_name in alias_api_mapping:
                api_names.extend(alias_api_mapping[mint_api_name])
            for api_name in api_names:
                func_docstr = _format_docstring(function_doc_data[api_name]["description"])
                mint_def_list.append(self.mint_def_template.replace(mint_func_name=api_name,
                                                                    docstr=func_docstr,
                                                                    cpp_func_name=mint_api_name))
                add_to_all_list.append(f'"{api_name}",\n')

        func_overload_py_file = self.FUNCTIONAL_OVERLOAD_PY_TEMPLATE.replace(import_mint_list=import_mint_list,
                                                                             mint_init_list=mint_init_list,
                                                                             mint_def_list=mint_def_list,
                                                                             add_to_all_list=add_to_all_list)
        save_path = os.path.join(work_path, K.MS_MINT_FUNC_OVERLOAD_PATH)
        file_name = "functional_overload.py"
        gen_utils.save_file(save_path, file_name, func_overload_py_file)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _format_docstring(docstring, indent_size=4):
|
|
88
|
+
if docstring is None:
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
lines = docstring.split('\n')
|
|
92
|
+
# Add 4 spaces to each line except first line
|
|
93
|
+
formatted_lines = ([' ' * indent_size + lines[0]] +
|
|
94
|
+
[' ' * indent_size + line if line.strip() else line for line in lines[1:]])
|
|
95
|
+
return '\n'.join(formatted_lines)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def validate_func_docs(mint_func_protos_data, function_doc_data, alias_api_mapping):
    """
    Ensure that API names and doc names correspond; otherwise, raise an error to prompt the developer.

    The API names are the keys of ``mint_func_protos_data`` plus every alias listed for them in
    ``alias_api_mapping``; the doc names are the keys of ``function_doc_data``.

    Args:
        mint_func_protos_data (dict): Mint API names mapped to their prototype data (only keys are used).
        function_doc_data (dict): Doc names mapped to their doc data (only keys are used).
        alias_api_mapping (dict): Mint API names mapped to an iterable of alias names.

    Raises:
        KeyError: If a doc name has no matching API (checked first), or an API/alias has no matching doc.
    """
    mint_api_names = set(mint_func_protos_data.keys())
    mint_doc_names = set(function_doc_data.keys())
    all_api_names = set(mint_api_names)
    for mint_api_name in mint_api_names:
        if mint_api_name in alias_api_mapping:
            all_api_names.update(alias_api_mapping[mint_api_name])

    # Docs that no generated API or alias would use.
    orphan_doc_names = mint_doc_names - all_api_names
    if orphan_doc_names:
        raise KeyError(f"Missing valid API references for the following doc names: {orphan_doc_names}, "
                       f"please check if their doc.yaml files are defined in mindspore/ops/api_def/function_doc.")

    # APIs or aliases whose docstring lookup would otherwise fail later with a bare KeyError.
    undocumented_api_names = all_api_names - mint_doc_names
    if undocumented_api_names:
        raise KeyError(f"Missing docs for the following API names: {sorted(undocumented_api_names)}, "
                       f"please add their doc.yaml files in mindspore/ops/api_def/function_doc.")
|