mindspore: mindspore-2.4.10-cp311-cp311-win_amd64.whl → mindspore-2.6.0rc1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the registry's advisory page for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2023-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -13,1087 +13,40 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""
|
|
16
|
-
|
|
16
|
+
Auto generate ops files.
|
|
17
17
|
"""
|
|
18
|
-
import os
|
|
19
|
-
import re
|
|
20
|
-
import shutil
|
|
21
|
-
import pathlib
|
|
22
18
|
import logging
|
|
23
|
-
import gen_utils
|
|
24
|
-
from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
|
|
25
|
-
merge_files_append, safe_load_yaml, convert_dtype_str, write_file)
|
|
26
|
-
from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
|
|
27
|
-
import template
|
|
28
|
-
from template import CppTemplate
|
|
29
|
-
from gen_pyboost_func import gen_pyboost_code
|
|
30
|
-
from gen_aclnn_implement import gen_aclnn_kernel
|
|
31
|
-
import gen_constants as K
|
|
32
19
|
|
|
20
|
+
from resources.resource_manager import prepare_resources
|
|
21
|
+
from common import gen_utils
|
|
33
22
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class_def = yaml_value.get("class")
|
|
40
|
-
if class_def is not None:
|
|
41
|
-
class_name_specify = class_def.get("name")
|
|
42
|
-
if class_name_specify is not None:
|
|
43
|
-
return class_name_specify
|
|
44
|
-
# Else use the default rule generate class name.
|
|
45
|
-
op_name = yaml_key
|
|
46
|
-
class_name_normal = ''.join(word.capitalize() for word in op_name.split('_'))
|
|
47
|
-
return class_name_normal
|
|
23
|
+
from op_def.gen_op_def import generate_ops_def_files
|
|
24
|
+
from op_def_py.gen_op_def_py import generate_ops_py_files
|
|
25
|
+
from api.gen_api import generate_api_files
|
|
26
|
+
from aclnn.aclnn_kernel_register_auto_cc_generator import generate_aclnn_reg_file
|
|
27
|
+
from pyboost.gen_pyboost_func import gen_pyboost_code
|
|
48
28
|
|
|
49
29
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
func_name = item
|
|
58
|
-
return func_name
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def _auto_generate_class_disabled(yaml_value):
|
|
62
|
-
"""Check whether class can be auto generated."""
|
|
63
|
-
if 'class' not in yaml_value.keys():
|
|
64
|
-
return False
|
|
65
|
-
class_def = yaml_value.get("class")
|
|
66
|
-
if 'disable' not in class_def.keys():
|
|
67
|
-
return False
|
|
68
|
-
disable_item = class_def.get("disable")
|
|
69
|
-
if disable_item is True:
|
|
70
|
-
return True
|
|
71
|
-
if disable_item is False:
|
|
72
|
-
return False
|
|
73
|
-
raise TypeError(f"The disable label for class should be True or False, but get {disable_item}.")
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def _auto_generate_func_disabled(yaml_value):
|
|
77
|
-
"""Check whether function can be auto generated."""
|
|
78
|
-
if 'function' not in yaml_value.keys():
|
|
79
|
-
return False
|
|
80
|
-
func_def = yaml_value.get('function')
|
|
81
|
-
if 'disable' not in func_def.keys():
|
|
82
|
-
return False
|
|
83
|
-
disable_item = func_def.get("disable")
|
|
84
|
-
if disable_item is True:
|
|
85
|
-
return True
|
|
86
|
-
if disable_item is False:
|
|
87
|
-
return False
|
|
88
|
-
raise TypeError(f"The disable label for function should be True or False, but get {disable_item}.")
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
|
|
92
|
-
"""
|
|
93
|
-
Generate signature rw code
|
|
94
|
-
"""
|
|
95
|
-
for rw_arg_name in write_list:
|
|
96
|
-
if rw_arg_name == arg_name:
|
|
97
|
-
return ', sig.sig_rw.RW_WRITE'
|
|
98
|
-
for read_arg_name in read_list:
|
|
99
|
-
if read_arg_name == arg_name:
|
|
100
|
-
return ', sig.sig_rw.RW_READ'
|
|
101
|
-
for ref_arg_name in ref_list:
|
|
102
|
-
if ref_arg_name == arg_name:
|
|
103
|
-
return ', sig.sig_rw.RW_REF'
|
|
104
|
-
return ''
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
|
|
108
|
-
"""
|
|
109
|
-
Generate cc signature rw code
|
|
110
|
-
"""
|
|
111
|
-
rw_label = 'kRWDefault'
|
|
112
|
-
for op in write_list:
|
|
113
|
-
if op == rw_op_name:
|
|
114
|
-
rw_label = 'kRWWrite'
|
|
115
|
-
for op in read_list:
|
|
116
|
-
if op == rw_op_name:
|
|
117
|
-
rw_label = 'kRWRead'
|
|
118
|
-
for op in ref_list:
|
|
119
|
-
if op == rw_op_name:
|
|
120
|
-
rw_label = 'kRWRef'
|
|
121
|
-
return 'SignatureEnumRW::' + rw_label
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def signature_get_enum_dtype_cc(index):
|
|
125
|
-
"""
|
|
126
|
-
Generate cc enum dtype code
|
|
127
|
-
"""
|
|
128
|
-
enum_type = 'SignatureEnumDType::'
|
|
129
|
-
type_map = {0: 'kDType',
|
|
130
|
-
1: 'kDType1',
|
|
131
|
-
2: 'kDType2',
|
|
132
|
-
3: 'kDType3',
|
|
133
|
-
4: 'kDType4',
|
|
134
|
-
5: 'kDType5',
|
|
135
|
-
6: 'kDType6',
|
|
136
|
-
7: 'kDType7',
|
|
137
|
-
8: 'kDType8',
|
|
138
|
-
9: 'kDType9'}
|
|
139
|
-
if index in type_map:
|
|
140
|
-
return enum_type + type_map[index]
|
|
141
|
-
return enum_type + 'kDTypeEmptyDefaultValue'
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def signature_get_dtype_label(index):
|
|
145
|
-
"""
|
|
146
|
-
Generate signature dtype code
|
|
147
|
-
"""
|
|
148
|
-
dtype_index = ''
|
|
149
|
-
if index > 0:
|
|
150
|
-
dtype_index = f"""{index}"""
|
|
151
|
-
return f"""dtype=sig.sig_dtype.T{dtype_index}"""
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
def get_same_dtype_groups(args_signature, args_name):
|
|
155
|
-
"""
|
|
156
|
-
Get same dtype groups
|
|
157
|
-
"""
|
|
158
|
-
same_dtype_groups = {}
|
|
159
|
-
dtype_conut = 0
|
|
160
|
-
if args_signature is None:
|
|
161
|
-
return same_dtype_groups, dtype_conut
|
|
162
|
-
|
|
163
|
-
dtype_group = args_signature.get('dtype_group')
|
|
164
|
-
if dtype_group is not None:
|
|
165
|
-
args_list = []
|
|
166
|
-
match = re.findall(r'\((.*?)\)', dtype_group)
|
|
167
|
-
for item in match:
|
|
168
|
-
args_list.append(item.replace(' ', '').split(","))
|
|
169
|
-
for arg_name in args_name:
|
|
170
|
-
if arg_name in same_dtype_groups:
|
|
171
|
-
continue
|
|
172
|
-
is_match = False
|
|
173
|
-
for group in args_list:
|
|
174
|
-
if arg_name in group:
|
|
175
|
-
is_match = True
|
|
176
|
-
for item in group:
|
|
177
|
-
same_dtype_groups[item] = dtype_conut
|
|
178
|
-
break
|
|
179
|
-
if not is_match:
|
|
180
|
-
same_dtype_groups[arg_name] = dtype_conut
|
|
181
|
-
dtype_conut = dtype_conut + 1
|
|
182
|
-
return same_dtype_groups, dtype_conut
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
def generate_py_op_signature(op_name, args_signature, args_name, args_default):
|
|
186
|
-
"""
|
|
187
|
-
Generate __mindspore_signature__
|
|
188
|
-
"""
|
|
189
|
-
|
|
190
|
-
def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
|
|
191
|
-
for sig_arg_name in sig_arg_names:
|
|
192
|
-
if sig_arg_name not in args_names:
|
|
193
|
-
raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
|
|
194
|
-
|
|
195
|
-
if args_signature is None and not args_default:
|
|
196
|
-
return ''
|
|
197
|
-
|
|
198
|
-
signature_code = f""" __mindspore_signature__ = """
|
|
199
|
-
|
|
200
|
-
# Init rw.
|
|
201
|
-
write_list = []
|
|
202
|
-
read_list = []
|
|
203
|
-
ref_list = []
|
|
204
|
-
if args_signature is not None:
|
|
205
|
-
rw_write = args_signature.get('rw_write')
|
|
206
|
-
rw_read = args_signature.get('rw_read')
|
|
207
|
-
rw_ref = args_signature.get('rw_ref')
|
|
208
|
-
if rw_write is not None:
|
|
209
|
-
write_list = rw_write.replace(' ', '').split(",")
|
|
210
|
-
_check_signature_arg_valid(op_name, write_list, args_name)
|
|
211
|
-
if rw_read is not None:
|
|
212
|
-
read_list = rw_read.replace(' ', '').split(",")
|
|
213
|
-
_check_signature_arg_valid(op_name, read_list, args_name)
|
|
214
|
-
if rw_ref is not None:
|
|
215
|
-
ref_list = rw_ref.replace(' ', '').split(",")
|
|
216
|
-
_check_signature_arg_valid(op_name, ref_list, args_name)
|
|
217
|
-
# Init dtype group.
|
|
218
|
-
same_dtype_groups, dtype_conut = get_same_dtype_groups(args_signature, args_name)
|
|
219
|
-
_check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
|
|
220
|
-
# Only one dtype_group is set.
|
|
221
|
-
if dtype_conut == 1 and not any([write_list, read_list, ref_list, args_default]):
|
|
222
|
-
signature_code += '('
|
|
223
|
-
for _ in range(len(args_name) - 1):
|
|
224
|
-
signature_code += 'sig.sig_dtype.T, '
|
|
225
|
-
signature_code += 'sig.sig_dtype.T)\n\n'
|
|
226
|
-
return signature_code
|
|
227
|
-
|
|
228
|
-
# Set sig.make_sig.
|
|
229
|
-
signature_code += f""" (\n"""
|
|
230
|
-
for arg_name in args_name:
|
|
231
|
-
signature_code += f""" sig.make_sig('{arg_name}'"""
|
|
232
|
-
signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
|
|
233
|
-
if arg_name in same_dtype_groups:
|
|
234
|
-
signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
|
|
235
|
-
if arg_name in args_default:
|
|
236
|
-
signature_code += f""", default=""" + str(args_default[arg_name])
|
|
237
|
-
signature_code += f"""),\n"""
|
|
238
|
-
signature_code += f""" )\n\n"""
|
|
239
|
-
return signature_code
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
def generate_cc_op_signature(args_signature, args_name):
|
|
243
|
-
"""
|
|
244
|
-
generate signatures on in cc file
|
|
245
|
-
:param args_signature:
|
|
246
|
-
:param args_name:
|
|
247
|
-
:return:
|
|
248
|
-
"""
|
|
249
|
-
if args_signature is None:
|
|
250
|
-
return ''
|
|
251
|
-
signature_code = ''
|
|
252
|
-
# Init rw.
|
|
253
|
-
write_list = []
|
|
254
|
-
read_list = []
|
|
255
|
-
ref_list = []
|
|
256
|
-
if args_signature is not None:
|
|
257
|
-
rw_write = args_signature.get('rw_write')
|
|
258
|
-
rw_read = args_signature.get('rw_read')
|
|
259
|
-
rw_ref = args_signature.get('rw_ref')
|
|
260
|
-
if rw_write is not None:
|
|
261
|
-
write_list = rw_write.replace(' ', '').split(",")
|
|
262
|
-
if rw_read is not None:
|
|
263
|
-
read_list = rw_read.replace(' ', '').split(",")
|
|
264
|
-
if rw_ref is not None:
|
|
265
|
-
ref_list = rw_ref.replace(' ', '').split(",")
|
|
266
|
-
# Init dtype group.
|
|
267
|
-
same_dtype_groups, _ = get_same_dtype_groups(args_signature, args_name)
|
|
268
|
-
for arg_name in args_name:
|
|
269
|
-
enum_rw = signature_get_rw_label_cc(arg_name, write_list, read_list, ref_list)
|
|
270
|
-
enum_dtype = signature_get_enum_dtype_cc(same_dtype_groups.get(arg_name))
|
|
271
|
-
signature = f"""Signature("{arg_name}", {enum_rw}, \
|
|
272
|
-
SignatureEnumKind::kKindPositionalKeyword, nullptr, {enum_dtype}),\n """
|
|
273
|
-
signature_code += signature
|
|
274
|
-
return signature_code
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
def generate_py_op_deprecated(deprecated):
|
|
278
|
-
"""
|
|
279
|
-
Generate @deprecated
|
|
280
|
-
"""
|
|
281
|
-
if deprecated is None:
|
|
282
|
-
return ''
|
|
283
|
-
version = deprecated.get("version")
|
|
284
|
-
if version is None:
|
|
285
|
-
raise ValueError("The version of deprecated can't be None.")
|
|
286
|
-
substitute = deprecated.get("substitute")
|
|
287
|
-
if substitute is None:
|
|
288
|
-
raise ValueError("The substitute of deprecated can't be None.")
|
|
289
|
-
use_substitute = deprecated.get("use_substitute")
|
|
290
|
-
if use_substitute is None:
|
|
291
|
-
raise ValueError("The use_substitute of deprecated can't be None.")
|
|
292
|
-
if use_substitute is not True and use_substitute is not False:
|
|
293
|
-
raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
|
|
294
|
-
|
|
295
|
-
deprecated = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
|
|
296
|
-
return deprecated
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
def _normalize_func_description_fromat(description):
|
|
300
|
-
"""
|
|
301
|
-
Process description.
|
|
302
|
-
"""
|
|
303
|
-
if not description:
|
|
304
|
-
return description
|
|
305
|
-
lines = description.split("\n")
|
|
306
|
-
if len(lines) == 1:
|
|
307
|
-
return description
|
|
308
|
-
# Add line indentation to other lines after the first line
|
|
309
|
-
for i in range(1, len(lines)):
|
|
310
|
-
indent = " " if lines[i] else ""
|
|
311
|
-
lines[i] = indent + lines[i]
|
|
312
|
-
# Remove trailing blank lines
|
|
313
|
-
lines = lines if lines[-1] != "" else lines[:-1]
|
|
314
|
-
description = "\n".join(lines)
|
|
315
|
-
return description
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
def _get_op_description(operator_name, doc_str):
|
|
319
|
-
"""
|
|
320
|
-
Generate ops api description.
|
|
321
|
-
"""
|
|
322
|
-
if doc_str is None:
|
|
323
|
-
print(f"Description is None, op_name: {operator_name}")
|
|
324
|
-
return ""
|
|
325
|
-
description = doc_str.get(operator_name)
|
|
326
|
-
if description is None:
|
|
327
|
-
print(f"Description is None, op_name: {operator_name}")
|
|
328
|
-
return ""
|
|
329
|
-
description = description.get("description")
|
|
330
|
-
if description is None:
|
|
331
|
-
print(f"Description is None, op_name: {operator_name}")
|
|
332
|
-
return ""
|
|
333
|
-
return _normalize_func_description_fromat(description)
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
def generate_py_op_func(yaml_data, doc_data):
|
|
337
|
-
"""
|
|
338
|
-
Generate operator python function api.
|
|
339
|
-
"""
|
|
340
|
-
gen_py = ''
|
|
341
|
-
|
|
342
|
-
for operator_name, operator_data in yaml_data.items():
|
|
343
|
-
if _auto_generate_func_disabled(operator_data):
|
|
344
|
-
continue
|
|
345
|
-
func_name = _get_op_func_name(operator_name, operator_data)
|
|
346
|
-
args = operator_data.get('args')
|
|
347
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
348
|
-
func_args = []
|
|
349
|
-
prim_init_args = []
|
|
350
|
-
prim_call_args = []
|
|
351
|
-
for arg_name, arg_info in args.items():
|
|
352
|
-
is_prim_init = arg_info.get('prim_init')
|
|
353
|
-
has_default = 'default' in arg_info.keys()
|
|
354
|
-
|
|
355
|
-
# step1: Process function args.
|
|
356
|
-
if not has_default:
|
|
357
|
-
func_args.append(f"""{arg_name}""")
|
|
358
|
-
else:
|
|
359
|
-
default_value = arg_info.get('default')
|
|
360
|
-
func_args.append(f"""{arg_name}={default_value}""")
|
|
361
|
-
|
|
362
|
-
# step2: Process primitive object init args.
|
|
363
|
-
if is_prim_init:
|
|
364
|
-
prim_init_args.append(arg_name)
|
|
365
|
-
|
|
366
|
-
# step3: Process primitive object call args.
|
|
367
|
-
else:
|
|
368
|
-
prim_call_args.append(arg_name)
|
|
369
|
-
description = _get_op_description(operator_name, doc_data)
|
|
370
|
-
function_code = f"""\n
|
|
371
|
-
def {func_name}({', '.join(arg for arg in func_args)}):
|
|
372
|
-
r\"\"\"
|
|
373
|
-
{description}
|
|
374
|
-
\"\"\"
|
|
375
|
-
{operator_name}_op = _get_cache_prim({class_name})({', '.join(arg_name for arg_name in prim_init_args)})
|
|
376
|
-
return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
|
|
377
|
-
|
|
378
|
-
if not prim_init_args:
|
|
379
|
-
if _auto_generate_class_disabled(operator_data):
|
|
380
|
-
gen_py += f"""\n{operator_name}_op={class_name}()"""
|
|
381
|
-
function_code = f"""\n
|
|
382
|
-
def {func_name}({', '.join(arg for arg in func_args)}):
|
|
383
|
-
r\"\"\"
|
|
384
|
-
{description}
|
|
385
|
-
\"\"\"
|
|
386
|
-
return {operator_name}_op({', '.join(arg_name for arg_name in prim_call_args)})\n"""
|
|
387
|
-
else:
|
|
388
|
-
dis = operator_data.get("dispatch")
|
|
389
|
-
if dis is not None:
|
|
390
|
-
enable_pyboost = dis.get("enable")
|
|
391
|
-
if enable_pyboost:
|
|
392
|
-
function_code = f"""\n
|
|
393
|
-
def {func_name}({', '.join(arg for arg in func_args)}):
|
|
394
|
-
r\"\"\"
|
|
395
|
-
{description}
|
|
396
|
-
\"\"\"
|
|
397
|
-
return {operator_name}_impl({', '.join(arg_name for arg_name, _ in args.items())})\n"""
|
|
398
|
-
gen_py += function_code
|
|
399
|
-
|
|
400
|
-
return gen_py
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
def get_dtype(arg_info):
|
|
404
|
-
dtype = arg_info.get('dtype')
|
|
405
|
-
# Currently, TypeId is represented by int
|
|
406
|
-
if dtype == 'TypeId':
|
|
407
|
-
dtype = 'int'
|
|
408
|
-
return dtype
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
def process_args(class_name, args):
|
|
412
|
-
"""
|
|
413
|
-
Process arg for yaml, get arg_name, init value, type cast, arg_handler, etc.
|
|
414
|
-
"""
|
|
415
|
-
inputs_name = []
|
|
416
|
-
args_name = []
|
|
417
|
-
args_assign = []
|
|
418
|
-
inputs_default = {}
|
|
419
|
-
init_args_with_default = []
|
|
420
|
-
args_handlers = {}
|
|
421
|
-
for arg_name, arg_info in args.items():
|
|
422
|
-
dtype = get_dtype(arg_info)
|
|
423
|
-
default_value = arg_info.get('default')
|
|
424
|
-
has_default = 'default' in arg_info.keys()
|
|
425
|
-
is_prim_init = arg_info.get('prim_init')
|
|
426
|
-
arg_handler = arg_info.get('arg_handler')
|
|
427
|
-
|
|
428
|
-
# step1: get args infos:
|
|
429
|
-
if is_prim_init:
|
|
430
|
-
# step1.1: get args name:
|
|
431
|
-
args_name.append(arg_name)
|
|
432
|
-
# step1.2: get args assign with default value:
|
|
433
|
-
if has_default:
|
|
434
|
-
init_args_with_default.append(f"""{arg_name}={default_value}""")
|
|
435
|
-
else:
|
|
436
|
-
init_args_with_default.append(f"""{arg_name}""")
|
|
437
|
-
|
|
438
|
-
# step1.3: get args set prim arg expression:
|
|
439
|
-
assign_str = gen_utils.get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype)
|
|
440
|
-
if arg_handler:
|
|
441
|
-
assign_str = f""" self._set_prim_arg_with_handler("{arg_name}", {assign_str}, {arg_handler})"""
|
|
442
|
-
else:
|
|
443
|
-
assign_str = f""" self._set_prim_arg("{arg_name}", {assign_str})"""
|
|
444
|
-
args_assign.append(assign_str)
|
|
445
|
-
# step2: get inputs infos:
|
|
446
|
-
else:
|
|
447
|
-
# step2.1: get inputs name:
|
|
448
|
-
inputs_name.append(arg_name)
|
|
449
|
-
|
|
450
|
-
# step2.2: get default value of inputs:
|
|
451
|
-
if has_default:
|
|
452
|
-
inputs_default[arg_name] = default_value
|
|
453
|
-
|
|
454
|
-
# step2.3: get args_handler functions for inputs
|
|
455
|
-
if arg_handler:
|
|
456
|
-
args_handlers[arg_name] = arg_handler
|
|
457
|
-
|
|
458
|
-
return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
def generate_pyboost_import_header(yaml_data):
|
|
462
|
-
"""
|
|
463
|
-
Generate python primitive
|
|
464
|
-
"""
|
|
465
|
-
pyboost_import_header = ''
|
|
466
|
-
import_pyboost = CppTemplate("from mindspore._c_expression import $var\n")
|
|
467
|
-
for operator_name, operator_data in yaml_data.items():
|
|
468
|
-
is_pyboost = is_pyboost_enable(operator_data)
|
|
469
|
-
if is_pyboost:
|
|
470
|
-
header = import_pyboost.replace(var=get_pyboost_name(operator_name))
|
|
471
|
-
pyboost_import_header += header
|
|
472
|
-
return pyboost_import_header
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
def _generate_class_description(class_name, func_name, input_args, init_args, func_disabled, doc_str):
|
|
476
|
-
"""Generate description for every primitive definition."""
|
|
477
|
-
if func_disabled:
|
|
478
|
-
# if function disabled, function name is equal to operator_name
|
|
479
|
-
description = _get_op_description(func_name, doc_str)
|
|
480
|
-
description = f""" r\"\"\"
|
|
481
|
-
{description}
|
|
482
|
-
\"\"\"
|
|
483
|
-
"""
|
|
484
|
-
return description
|
|
485
|
-
|
|
486
|
-
# If function is an released API, refer to the function doc.
|
|
487
|
-
description_str = f""" r\"\"\"
|
|
488
|
-
.. code-block::
|
|
489
|
-
|
|
490
|
-
prim = ops.{class_name}({', '.join(init_args)})
|
|
491
|
-
out = prim({', '.join(input_args)})
|
|
492
|
-
|
|
493
|
-
is equivalent to
|
|
494
|
-
|
|
495
|
-
.. code-block::
|
|
496
|
-
|
|
497
|
-
ops.{func_name}({", ".join(input_args + init_args)})
|
|
498
|
-
|
|
499
|
-
Refer to :func:`mindspore.ops.{func_name}` for more details.
|
|
500
|
-
\"\"\"
|
|
501
|
-
"""
|
|
502
|
-
return description_str
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
def get_init_code(init_code, operator_data):
|
|
506
|
-
"""
|
|
507
|
-
Generate init code for primitive
|
|
508
|
-
"""
|
|
509
|
-
labels = operator_data.get('labels')
|
|
510
|
-
if labels is not None:
|
|
511
|
-
if init_code != "":
|
|
512
|
-
init_code += "\n"
|
|
513
|
-
init_code += \
|
|
514
|
-
'\n'.join([f""" self.add_prim_attr("{key}", {value})""" for key, value in labels.items()])
|
|
515
|
-
if init_code == "":
|
|
516
|
-
init_code = f""" pass"""
|
|
517
|
-
return init_code
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
def generate_py_primitive(yaml_data, doc_str):
|
|
521
|
-
"""
|
|
522
|
-
Generate python primitive
|
|
523
|
-
"""
|
|
524
|
-
|
|
525
|
-
def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
|
|
526
|
-
"""Generate arg_handler"""
|
|
527
|
-
arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
|
|
528
|
-
if is_optional:
|
|
529
|
-
arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
|
|
530
|
-
return arg_handler_call
|
|
531
|
-
|
|
532
|
-
gen_py = ''
|
|
533
|
-
for operator_name, operator_data in yaml_data.items():
|
|
534
|
-
if _auto_generate_class_disabled(operator_data):
|
|
535
|
-
continue
|
|
536
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
537
|
-
func_name = _get_op_func_name(operator_name, operator_data)
|
|
538
|
-
pyboost_func_name = get_pyboost_name(operator_name)
|
|
539
|
-
args = operator_data.get('args')
|
|
540
|
-
inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = \
|
|
541
|
-
process_args(class_name, args)
|
|
542
|
-
init_code = '\n'.join(args_assign)
|
|
543
|
-
signature_code = generate_py_op_signature(class_name, operator_data.get('args_signature'), inputs_args,
|
|
544
|
-
inputs_default)
|
|
545
|
-
deprecated_code = generate_py_op_deprecated(operator_data.get('deprecated'))
|
|
546
|
-
init_code = get_init_code(init_code, operator_data)
|
|
547
|
-
primitive_code = f"""\n
|
|
548
|
-
class {class_name}(Primitive):\n"""
|
|
549
|
-
func_disabled = _auto_generate_func_disabled(operator_data)
|
|
550
|
-
primitive_code += _generate_class_description(class_name, func_name, inputs_args, init_args, func_disabled,
|
|
551
|
-
doc_str)
|
|
552
|
-
if signature_code != "":
|
|
553
|
-
primitive_code += signature_code
|
|
554
|
-
if deprecated_code != "":
|
|
555
|
-
primitive_code += deprecated_code
|
|
556
|
-
primitive_code += f""" @prim_arg_register
|
|
557
|
-
def __init__(self"""
|
|
558
|
-
if init_args_with_default:
|
|
559
|
-
primitive_code += ", " + f"""{', '.join(init_args_with_default) if init_args_with_default else ''}"""
|
|
560
|
-
call_args = []
|
|
561
|
-
for name in inputs_args:
|
|
562
|
-
call_args.append(f"""{name}={inputs_default[name]}""" if name in inputs_default else name)
|
|
563
|
-
primitive_code += f"""):
|
|
564
|
-
{init_code}
|
|
565
|
-
|
|
566
|
-
def __call__(self, {', '.join(call_args)}):"""
|
|
567
|
-
is_pyboost = is_pyboost_enable(operator_data)
|
|
568
|
-
if is_pyboost:
|
|
569
|
-
primitive_code += f"""
|
|
570
|
-
return _convert_stub({pyboost_func_name}(self, ["""
|
|
571
|
-
else:
|
|
572
|
-
primitive_code += f"""
|
|
573
|
-
return super().__call__("""
|
|
574
|
-
if inputs_args:
|
|
575
|
-
args_with_handler = []
|
|
576
|
-
for arg in inputs_args:
|
|
577
|
-
if arg in args_handlers:
|
|
578
|
-
is_optional = inputs_default.get(arg) == "None"
|
|
579
|
-
args_with_handler.append(_generate_arg_handler(class_name, arg, args_handlers[arg], is_optional))
|
|
580
|
-
else:
|
|
581
|
-
args_with_handler.append(arg)
|
|
582
|
-
primitive_code += ', '.join(args_with_handler)
|
|
583
|
-
|
|
584
|
-
if init_args:
|
|
585
|
-
primitive_code += ', '
|
|
586
|
-
primitive_code += ', '.join([f'self.{arg}' for arg in init_args])
|
|
587
|
-
if is_pyboost:
|
|
588
|
-
primitive_code += """]))"""
|
|
589
|
-
else:
|
|
590
|
-
primitive_code += """)
|
|
591
|
-
"""
|
|
592
|
-
|
|
593
|
-
gen_py += primitive_code
|
|
594
|
-
if not init_args:
|
|
595
|
-
prim_op_object = f"""\n
|
|
596
|
-
{operator_name}_op={class_name}()
|
|
597
|
-
"""
|
|
598
|
-
gen_py += prim_op_object
|
|
599
|
-
return gen_py
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
def generate_op_name_opdef(yaml_data):
|
|
603
|
-
"""
|
|
604
|
-
Generate op name
|
|
605
|
-
"""
|
|
606
|
-
op_name_head = f"""
|
|
607
|
-
#ifndef MINDSPORE_CORE_OP_NAME_H_
|
|
608
|
-
#define MINDSPORE_CORE_OP_NAME_H_
|
|
609
|
-
|
|
610
|
-
namespace mindspore::ops {{
|
|
611
|
-
"""
|
|
612
|
-
|
|
613
|
-
op_name_end = f"""}} // namespace mindspore::ops
|
|
614
|
-
|
|
615
|
-
#endif // MINDSPORE_CORE_OP_NAME_H_
|
|
616
|
-
"""
|
|
617
|
-
|
|
618
|
-
op_name_gen = ''
|
|
619
|
-
op_name_gen += op_name_head
|
|
620
|
-
for operator_name, operator_data in yaml_data.items():
|
|
621
|
-
k_name_op = _get_op_name(operator_name, operator_data)
|
|
622
|
-
op_name_gen += f"""constexpr auto kName{k_name_op} = "{k_name_op}";
|
|
623
|
-
"""
|
|
624
|
-
|
|
625
|
-
op_name_gen += op_name_end
|
|
626
|
-
return op_name_gen
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def generate_op_prim_opdef(yaml_data):
|
|
630
|
-
"""
|
|
631
|
-
Generate primitive c++ definition
|
|
632
|
-
"""
|
|
633
|
-
ops_prim_head = f"""
|
|
634
|
-
#ifndef MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
|
|
635
|
-
#define MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
|
|
636
|
-
|
|
637
|
-
#include <memory>
|
|
638
|
-
#include "ir/anf.h"
|
|
639
|
-
#include "ir/primitive.h"
|
|
640
|
-
#include "{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
|
|
641
|
-
#include "mindapi/base/macros.h"
|
|
642
|
-
|
|
643
|
-
namespace mindspore::prim {{
|
|
644
|
-
"""
|
|
645
|
-
|
|
646
|
-
ops_prim_end = f"""}} // namespace mindspore::prim
|
|
647
|
-
#endif // MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
|
|
648
|
-
"""
|
|
649
|
-
|
|
650
|
-
ops_prim_gen = ''
|
|
651
|
-
ops_prim_gen += ops_prim_head
|
|
652
|
-
for operator_name, operator_data in yaml_data.items():
|
|
653
|
-
k_name_op = _get_op_name(operator_name, operator_data)
|
|
654
|
-
ops_prim_gen += f"""GVAR_DEF(PrimitivePtr, kPrim{k_name_op}, std::make_shared<Primitive>(ops::kName{k_name_op}))
|
|
655
|
-
"""
|
|
656
|
-
ops_prim_gen += ops_prim_end
|
|
657
|
-
return ops_prim_gen
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def generate_lite_ops(yaml_data):
|
|
661
|
-
"""
|
|
662
|
-
Generate BaseOperator parameter set and get func
|
|
663
|
-
"""
|
|
664
|
-
lite_ops_h_head = f"""
|
|
665
|
-
#ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
|
|
666
|
-
#define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
|
|
667
|
-
|
|
668
|
-
#include <vector>
|
|
669
|
-
#include "ops/base_operator.h"
|
|
670
|
-
#include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_ops_name.h"
|
|
671
|
-
|
|
672
|
-
namespace mindspore::ops {{
|
|
673
|
-
"""
|
|
674
|
-
|
|
675
|
-
lite_ops_h_end = f"""}} // namespace mindspore::ops
|
|
676
|
-
#endif // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
|
|
677
|
-
"""
|
|
678
|
-
|
|
679
|
-
lite_ops_cc_head = f"""
|
|
680
|
-
#include "{K.OP_DEF_AUTO_GENERATE_PATH}/gen_lite_ops.h"
|
|
681
|
-
#include "mindapi/helper.h"
|
|
682
|
-
#include "ops/primitive_c.h"
|
|
683
|
-
#include "ops/base_operator.h"
|
|
684
|
-
#include "abstract/abstract_value.h"
|
|
685
|
-
|
|
686
|
-
namespace mindspore::ops {{
|
|
687
|
-
"""
|
|
688
|
-
|
|
689
|
-
lite_ops_cc_end = f"""}} // namespace mindspore::ops
|
|
690
|
-
"""
|
|
691
|
-
|
|
692
|
-
lite_ops_h_gen = ''
|
|
693
|
-
lite_ops_cc_gen = ''
|
|
694
|
-
|
|
695
|
-
lite_ops_h_gen += lite_ops_h_head
|
|
696
|
-
lite_ops_cc_gen += lite_ops_cc_head
|
|
697
|
-
for operator_name, operator_data in yaml_data.items():
|
|
698
|
-
op_name = _get_op_name(operator_name, operator_data)
|
|
699
|
-
lite_ops_h_gen += f"""class OPS_API {op_name} : public BaseOperator {{
|
|
700
|
-
public:
|
|
701
|
-
MIND_API_BASE_MEMBER({op_name});
|
|
702
|
-
{op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
|
|
703
|
-
args = operator_data.get('args')
|
|
704
|
-
for _, (arg_name, arg_info) in enumerate(args.items()):
|
|
705
|
-
is_prim_init = arg_info.get('prim_init')
|
|
706
|
-
if not is_prim_init:
|
|
707
|
-
continue
|
|
708
|
-
|
|
709
|
-
dtype = get_dtype(arg_info)
|
|
710
|
-
if dtype == "str":
|
|
711
|
-
dtype = "std::string"
|
|
712
|
-
if dtype in ("tuple[str]", "list[str]"):
|
|
713
|
-
dtype = "std::vector<std::string>"
|
|
714
|
-
if dtype in ("tuple[int]", "list[int]"):
|
|
715
|
-
dtype = "std::vector<int64_t>"
|
|
716
|
-
if dtype in ("tuple[float]", "list[float]"):
|
|
717
|
-
dtype = "std::vector<float>"
|
|
718
|
-
if dtype in ("tuple[bool]", "list[bool]"):
|
|
719
|
-
dtype = "std::vector<bool>"
|
|
720
|
-
if dtype == "int":
|
|
721
|
-
dtype = "int64_t"
|
|
722
|
-
lite_ops_h_gen += f""" void set_{arg_name}(const {dtype} &{arg_name});\n"""
|
|
723
|
-
lite_ops_h_gen += f""" {dtype} get_{arg_name}() const;\n"""
|
|
724
|
-
|
|
725
|
-
lite_ops_cc_gen += f"""void {op_name}::set_{arg_name}(const {dtype} &{arg_name}) \
|
|
726
|
-
{{ (void)this->AddAttr("{arg_name}", api::MakeValue({arg_name})); }}\n\n"""
|
|
727
|
-
lite_ops_cc_gen += f"""{dtype} {op_name}::get_{arg_name}() const \
|
|
728
|
-
{{ return GetValue<{dtype}>(GetAttr("{arg_name}")); }}\n\n"""
|
|
729
|
-
|
|
730
|
-
op_name = _get_op_name(operator_name, operator_data)
|
|
731
|
-
lite_ops_cc_gen += f"""REGISTER_PRIMITIVE_C(kName{op_name}, {op_name});\n"""
|
|
732
|
-
lite_ops_cc_gen += f"""MIND_API_OPERATOR_IMPL({op_name}, BaseOperator);\n\n"""
|
|
733
|
-
lite_ops_h_gen += f"""}};\n\n"""
|
|
734
|
-
lite_ops_h_gen += lite_ops_h_end
|
|
735
|
-
lite_ops_cc_gen += lite_ops_cc_end
|
|
736
|
-
return lite_ops_h_gen, lite_ops_cc_gen
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
def generate_cc_opdef(yaml_data):
|
|
740
|
-
"""
|
|
741
|
-
Generate c++ OpDef
|
|
742
|
-
"""
|
|
743
|
-
gen_cc_code = f"""\n
|
|
744
|
-
namespace mindspore::ops {{"""
|
|
745
|
-
gen_include = f"""\n
|
|
746
|
-
#include \"{K.MS_OP_DEF_AUTO_GENERATE_PATH}/gen_ops_def.h\""""
|
|
747
|
-
gen_include += f"""
|
|
748
|
-
#include \"ir/signature.h\""""
|
|
749
|
-
|
|
750
|
-
for operator_name, operator_data in yaml_data.items():
|
|
751
|
-
args = operator_data.get('args')
|
|
752
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
753
|
-
inputs_args, _, _, _, _, _ = process_args(class_name, args)
|
|
754
|
-
signature_code = generate_cc_op_signature(operator_data.get('args_signature'), inputs_args)
|
|
755
|
-
args = operator_data.get('args')
|
|
756
|
-
returns = operator_data.get('returns')
|
|
757
|
-
dispatch = operator_data.get("dispatch")
|
|
758
|
-
# dispatch not defined in yaml or dispatch.enable==False
|
|
759
|
-
if not dispatch or not dispatch.get("enable"):
|
|
760
|
-
dispatch = "false"
|
|
761
|
-
else:
|
|
762
|
-
dispatch = "true"
|
|
763
|
-
enable_dispatch_str = f"""{dispatch}"""
|
|
764
|
-
|
|
765
|
-
is_view = operator_data.get('view')
|
|
766
|
-
if is_view:
|
|
767
|
-
is_view_s = "true"
|
|
768
|
-
else:
|
|
769
|
-
is_view_s = "false"
|
|
770
|
-
is_view_str = f"""{is_view_s}"""
|
|
771
|
-
|
|
772
|
-
gen_include += f"""\n#include "{K.MS_OPS_FUNC_IMPL_PATH}/{operator_name}.h\""""
|
|
773
|
-
cc_index_str = ''
|
|
774
|
-
input_args_str = ''
|
|
775
|
-
args_dict = {}
|
|
776
|
-
for i, (arg_name, arg_info) in enumerate(args.items()):
|
|
777
|
-
args_dict[arg_name] = i
|
|
778
|
-
cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
|
|
779
|
-
dtype = get_dtype(arg_info)
|
|
780
|
-
cc_dtype_str = convert_dtype_str(dtype)
|
|
781
|
-
|
|
782
|
-
is_prim_init = 1 if arg_info.get('prim_init') else 0
|
|
783
|
-
arg_handler = arg_info.get('arg_handler')
|
|
784
|
-
arg_handler_str = "" if arg_handler is None else arg_handler
|
|
785
|
-
|
|
786
|
-
type_cast = arg_info.get('type_cast')
|
|
787
|
-
type_cast_str = "" if type_cast is None else \
|
|
788
|
-
', '.join('DT_' + type.replace('[', '_').replace(']', '').upper() for type in
|
|
789
|
-
(ct.strip() for ct in type_cast.split(",")))
|
|
790
|
-
|
|
791
|
-
# default: None is regarded as a optional argument.
|
|
792
|
-
is_optional_str = "false"
|
|
793
|
-
if 'default' in arg_info.keys() and arg_info.get('default') == "None":
|
|
794
|
-
is_optional_str = "true"
|
|
795
|
-
|
|
796
|
-
input_args_str += f"""\n {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """ + \
|
|
797
|
-
f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """ + \
|
|
798
|
-
f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},"""
|
|
799
|
-
|
|
800
|
-
# Process outputs.
|
|
801
|
-
return_args_str = ''
|
|
802
|
-
for return_name, return_info in returns.items():
|
|
803
|
-
return_dtype = return_info.get('dtype')
|
|
804
|
-
ref_name = return_info.get('inplace')
|
|
805
|
-
ref_index_str = -1 if ref_name is None else args_dict.get(ref_name)
|
|
806
|
-
cc_return_type_str = 'DT_' + return_dtype.replace('[', '_').replace(']', '').upper()
|
|
807
|
-
return_args_str += f"""{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},
|
|
808
|
-
/*.inplace_input_index_=*/{ref_index_str}}},\n"""
|
|
809
|
-
|
|
810
|
-
op_def_cc = template.OP_PROTO_TEMPLATE.replace(class_name=class_name, input_args=input_args_str,
|
|
811
|
-
return_args=return_args_str, signatures=signature_code,
|
|
812
|
-
indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
|
|
813
|
-
is_view=is_view_str)
|
|
814
|
-
gen_cc_code += op_def_cc
|
|
815
|
-
if is_view:
|
|
816
|
-
view_op_def = op_def_cc.replace(class_name, class_name+"View")
|
|
817
|
-
gen_cc_code += view_op_def
|
|
818
|
-
|
|
819
|
-
cc_opdef_end = f"""\n}} // namespace mindspore::ops\n"""
|
|
820
|
-
return gen_include + gen_cc_code + cc_opdef_end
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
ops_py_prim_header = f"""
|
|
824
|
-
\"\"\"Operators definition generated by gen_ops.py, includes primitive classes.\"\"\"
|
|
825
|
-
|
|
826
|
-
from mindspore.ops.primitive import Primitive, prim_arg_register
|
|
827
|
-
from mindspore.ops import signature as sig
|
|
828
|
-
from mindspore.common import dtype as mstype
|
|
829
|
-
from mindspore.common._decorator import deprecated
|
|
830
|
-
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
831
|
-
from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
|
|
832
|
-
from mindspore.ops.auto_generate.gen_arg_handler import *
|
|
833
|
-
from mindspore._c_expression import OpDtype
|
|
834
|
-
from mindspore.common._stub_tensor import _convert_stub
|
|
835
|
-
"""
|
|
836
|
-
|
|
837
|
-
ops_py_def_header = f"""
|
|
838
|
-
\"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
|
|
839
|
-
|
|
840
|
-
from .gen_ops_prim import *
|
|
841
|
-
from .pyboost_inner_prim import *
|
|
842
|
-
from mindspore.ops.operations.manually_defined.ops_def import *
|
|
843
|
-
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
844
|
-
"""
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
|
|
848
|
-
py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_prim.py')
|
|
849
|
-
tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_prim.py')
|
|
850
|
-
pyboost_import_header = generate_pyboost_import_header(yaml_str)
|
|
851
|
-
py_prim = generate_py_primitive(yaml_str, doc_str)
|
|
852
|
-
write_file(tmp_py_path, py_licence_str + ops_py_prim_header + pyboost_import_header + py_prim)
|
|
853
|
-
check_change_and_replace_file(py_path, tmp_py_path)
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
|
|
857
|
-
py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/{file_pre}_ops_def.py')
|
|
858
|
-
tmp_py_path = os.path.join(work_path, f'{K.PY_AUTO_GEN_PATH}/tmp_{file_pre}_ops_def.py')
|
|
859
|
-
py_func = generate_py_op_func(yaml_str, doc_str)
|
|
860
|
-
write_file(tmp_py_path, py_licence_str + ops_py_def_header + py_func)
|
|
861
|
-
check_change_and_replace_file(py_path, tmp_py_path)
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
|
|
865
|
-
"""
|
|
866
|
-
Generate ops python file from yaml.
|
|
867
|
-
"""
|
|
868
|
-
generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre)
|
|
869
|
-
generate_ops_def_file(work_path, yaml_str, doc_str, file_pre)
|
|
870
|
-
shutil.copy(os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_auto_generate_init.txt'),
|
|
871
|
-
os.path.join(work_path, K.PY_AUTO_GEN_PATH, "__init__.py"))
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
def generate_ops_cc_files(work_path, yaml_str):
|
|
875
|
-
"""
|
|
876
|
-
Generate ops c++ file from yaml.
|
|
877
|
-
"""
|
|
878
|
-
# ops_def
|
|
879
|
-
op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_def.cc')
|
|
880
|
-
tmp_op_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_def.cc')
|
|
881
|
-
cc_def_code = generate_cc_opdef(yaml_str)
|
|
882
|
-
write_file(tmp_op_cc_path, cc_license_str + cc_def_code)
|
|
883
|
-
check_change_and_replace_file(op_cc_path, tmp_op_cc_path)
|
|
884
|
-
|
|
885
|
-
# ops_primitive
|
|
886
|
-
op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_primitive.h')
|
|
887
|
-
tmp_op_prim_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_primitive.h')
|
|
888
|
-
op_prim_code = generate_op_prim_opdef(yaml_str)
|
|
889
|
-
write_file(tmp_op_prim_path, cc_license_str + op_prim_code)
|
|
890
|
-
check_change_and_replace_file(op_prim_path, tmp_op_prim_path)
|
|
891
|
-
|
|
892
|
-
# lite_h_ops
|
|
893
|
-
lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.h')
|
|
894
|
-
tmp_lite_ops_h_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.h')
|
|
895
|
-
lite_ops_h_code, lite_ops_cc_code = generate_lite_ops(yaml_str)
|
|
896
|
-
write_file(tmp_lite_ops_h_path, cc_license_str + lite_ops_h_code)
|
|
897
|
-
check_change_and_replace_file(lite_ops_h_path, tmp_lite_ops_h_path)
|
|
898
|
-
|
|
899
|
-
# lite_cc_ops
|
|
900
|
-
lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_lite_ops.cc')
|
|
901
|
-
tmp_lite_ops_cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_lite_ops.cc')
|
|
902
|
-
write_file(tmp_lite_ops_cc_path, cc_license_str + lite_ops_cc_code)
|
|
903
|
-
check_change_and_replace_file(lite_ops_cc_path, tmp_lite_ops_cc_path)
|
|
904
|
-
|
|
905
|
-
# ops_names
|
|
906
|
-
op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'gen_ops_name.h')
|
|
907
|
-
tmp_op_name_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH, 'tmp_gen_ops_name.h')
|
|
908
|
-
op_name_code = generate_op_name_opdef(yaml_str)
|
|
909
|
-
write_file(tmp_op_name_path, cc_license_str + op_name_code)
|
|
910
|
-
check_change_and_replace_file(op_name_path, tmp_op_name_path)
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
def generate_op_labels(yaml_data):
|
|
914
|
-
"""
|
|
915
|
-
Generate python labels
|
|
916
|
-
"""
|
|
917
|
-
gen_label_py = f"""op_labels = {{"""
|
|
918
|
-
for operator_name, operator_data in yaml_data.items():
|
|
919
|
-
labels = operator_data.get('labels')
|
|
920
|
-
if labels is not None:
|
|
921
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
922
|
-
gen_label_py += f"""
|
|
923
|
-
"{class_name}": {{"""
|
|
924
|
-
gen_label_py += f""", """.join([f""""{key}": {value}""" for key, value in labels.items()])
|
|
925
|
-
gen_label_py += f"""}},"""
|
|
926
|
-
gen_label_py += f"""
|
|
927
|
-
}}"""
|
|
928
|
-
return gen_label_py
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
def generate_op_arg_default_value(yaml_data):
|
|
932
|
-
"""
|
|
933
|
-
Generate python default value.
|
|
934
|
-
"""
|
|
935
|
-
default_py_header = f"""\"\"\"Operator labels and args default value.\"\"\"
|
|
936
|
-
from mindspore.common import dtype as mstype\n\n"""
|
|
937
|
-
|
|
938
|
-
gen_default_py = default_py_header + f"""op_args_default_value = {{"""
|
|
939
|
-
for operator_name, operator_data in yaml_data.items():
|
|
940
|
-
arg_default_dict = {}
|
|
941
|
-
args = operator_data.get('args')
|
|
942
|
-
for arg_name, arg_info in args.items():
|
|
943
|
-
arg_default = arg_info.get('default')
|
|
944
|
-
if arg_default is not None:
|
|
945
|
-
arg_default_dict[arg_name] = arg_default
|
|
946
|
-
if arg_default_dict:
|
|
947
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
948
|
-
gen_default_py += f"""
|
|
949
|
-
"{class_name}": {{"""
|
|
950
|
-
gen_default_py += f""", """.join([f""""{key}": {value}""" for key, value in arg_default_dict.items()])
|
|
951
|
-
gen_default_py += f"""}},"""
|
|
952
|
-
gen_default_py += f"""
|
|
953
|
-
}}"""
|
|
954
|
-
return gen_default_py
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
def generate_create_instance_helper_file(work_path, yaml_str):
|
|
958
|
-
"""
|
|
959
|
-
Generate C++ helper file from yaml.
|
|
960
|
-
"""
|
|
961
|
-
dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
|
|
962
|
-
op_py_path = os.path.join(dst_dir, 'cpp_create_prim_instance_helper.py')
|
|
963
|
-
tmp_op_py_path = os.path.join(dst_dir, 'tmp_cpp_create_prim_instance_helper.py')
|
|
964
|
-
py_labels = generate_op_labels(yaml_str)
|
|
965
|
-
py_arg_default = generate_op_arg_default_value(yaml_str)
|
|
966
|
-
write_file(tmp_op_py_path, py_licence_str + "\n" + py_arg_default + "\n\n" + py_labels + "\n")
|
|
967
|
-
check_change_and_replace_file(op_py_path, tmp_op_py_path)
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
def generate_aclnn_reg_code(yaml_data):
|
|
971
|
-
"""generate aclnn register code"""
|
|
972
|
-
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
973
|
-
work_path = os.path.join(current_path, '../../../../')
|
|
974
|
-
ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, "ops.yaml")
|
|
975
|
-
yaml_str = gen_utils.safe_load_yaml(ops_yaml_path)
|
|
976
|
-
|
|
977
|
-
reg_code = f"""
|
|
978
|
-
#include "{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
|
|
979
|
-
|
|
980
|
-
namespace mindspore {{
|
|
981
|
-
namespace kernel {{
|
|
982
|
-
"""
|
|
983
|
-
for operator_name, operator_data in yaml_data.items():
|
|
984
|
-
dispatch = operator_data.get("dispatch")
|
|
985
|
-
if not dispatch or not dispatch.get("enable"):
|
|
986
|
-
continue
|
|
987
|
-
Ascend = dispatch.get("Ascend")
|
|
988
|
-
if Ascend is not None: # KernelMod is provided by yaml, don't auto generate it.
|
|
989
|
-
continue
|
|
990
|
-
_, _, none_tensor_exist = get_dtypes(operator_data)
|
|
991
|
-
if none_tensor_exist:
|
|
992
|
-
gen_aclnn_kernel(operator_name, yaml_str, auto=True)
|
|
993
|
-
continue
|
|
994
|
-
class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
|
|
995
|
-
op_class = operator_data.get("class")
|
|
996
|
-
if op_class and op_class.get("name") is not None:
|
|
997
|
-
class_name = op_class.get("name")
|
|
998
|
-
inputs_outputs_num = len(operator_data.get("args")) + len(operator_data.get("returns"))
|
|
999
|
-
aclnn_name = AclnnUtils.get_aclnn_interface(class_name)
|
|
1000
|
-
reg_code += f"""
|
|
1001
|
-
MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});"""
|
|
1002
|
-
reg_code += f"""
|
|
1003
|
-
}} // namespace kernel
|
|
1004
|
-
}} // namespace mindspore
|
|
1005
|
-
"""
|
|
1006
|
-
return reg_code
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
def generate_aclnn_reg_file(work_path, yaml_str):
|
|
1010
|
-
"""
|
|
1011
|
-
Generate nnacl kernelmod register
|
|
1012
|
-
"""
|
|
1013
|
-
tmp_register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/tmp_aclnn_kernel_register.cc'
|
|
1014
|
-
register_file = work_path + f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_register_auto.cc'
|
|
1015
|
-
reg_code = generate_aclnn_reg_code(yaml_str)
|
|
1016
|
-
write_file(tmp_register_file, cc_license_str + reg_code)
|
|
1017
|
-
check_change_and_replace_file(register_file, tmp_register_file)
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
def generate_arg_handler_files(work_path):
|
|
1021
|
-
"""
|
|
1022
|
-
Generate arg handler files.
|
|
1023
|
-
"""
|
|
1024
|
-
dst_dir = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
|
|
1025
|
-
src_arg_handler_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_handler.py')
|
|
1026
|
-
dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
|
|
1027
|
-
tmp_dst_arg_handler_path = os.path.join(dst_dir, 'tmp_gen_arg_handler.py')
|
|
1028
|
-
if not os.path.exists(dst_dir):
|
|
1029
|
-
os.makedirs(dst_dir, mode=0o700)
|
|
1030
|
-
shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
|
|
1031
|
-
check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)
|
|
1032
|
-
|
|
1033
|
-
src_arg_dtype_cast_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'arg_dtype_cast.py')
|
|
1034
|
-
dst_arg_dtype_cast_path = os.path.join(dst_dir, 'gen_arg_dtype_cast.py')
|
|
1035
|
-
tmp_arg_dtype_cast_path = os.path.join(dst_dir, 'tmp_arg_dtype_cast.py')
|
|
1036
|
-
shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
|
1037
|
-
check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
def get_view_ops(yaml_data):
|
|
1041
|
-
"""
|
|
1042
|
-
Get ops with view: True
|
|
1043
|
-
"""
|
|
1044
|
-
view_ops = []
|
|
1045
|
-
for operator_name, operator_data in yaml_data.items():
|
|
1046
|
-
class_name = _get_op_name(operator_name, operator_data)
|
|
1047
|
-
view = operator_data.get("view")
|
|
1048
|
-
if view:
|
|
1049
|
-
view_ops.append(class_name + "View")
|
|
1050
|
-
return view_ops
|
|
30
|
+
module_generators = [
|
|
31
|
+
generate_ops_py_files, # generate ops python files
|
|
32
|
+
generate_ops_def_files, # generate ops definition files
|
|
33
|
+
gen_pyboost_code, # generate pyboost code
|
|
34
|
+
generate_aclnn_reg_file, # generate aclnn kernelmod register
|
|
35
|
+
generate_api_files # generate api definition files
|
|
36
|
+
]
|
|
1051
37
|
|
|
1052
38
|
|
|
1053
39
|
def main():
|
|
1054
|
-
|
|
1055
|
-
work_path = os.path.join(current_path, '../../../../')
|
|
1056
|
-
|
|
1057
|
-
# merge ops yaml
|
|
1058
|
-
ops_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops.yaml')
|
|
1059
|
-
doc_yaml_path = os.path.join(work_path, K.PY_OPS_GEN_PATH, 'ops_doc.yaml')
|
|
1060
|
-
|
|
1061
|
-
ops_yaml_dir_path = os.path.join(work_path, K.MS_YAML_PATH)
|
|
1062
|
-
infer_ops_yaml_dir_path = os.path.join(ops_yaml_dir_path, "infer")
|
|
1063
|
-
doc_yaml_dir_path = os.path.join(ops_yaml_dir_path, "doc")
|
|
1064
|
-
merge_files(ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
|
|
1065
|
-
merge_files_append(infer_ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
|
|
1066
|
-
merge_files(doc_yaml_dir_path, doc_yaml_path, '*doc.yaml')
|
|
1067
|
-
|
|
1068
|
-
# make auto_generate dir
|
|
1069
|
-
cc_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
|
|
1070
|
-
pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)
|
|
1071
|
-
|
|
1072
|
-
# generate arg_handler files
|
|
1073
|
-
generate_arg_handler_files(work_path)
|
|
1074
|
-
|
|
1075
|
-
# read ops definition str and doc str
|
|
1076
|
-
ops_yaml_str = safe_load_yaml(ops_yaml_path)
|
|
1077
|
-
doc_yaml_str = safe_load_yaml(doc_yaml_path)
|
|
1078
|
-
|
|
1079
|
-
# generate ops python files
|
|
1080
|
-
generate_ops_py_files(work_path, ops_yaml_str, doc_yaml_str, "gen")
|
|
40
|
+
resource_mgr = prepare_resources()
|
|
1081
41
|
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
# generate create prim instance helper file
|
|
1085
|
-
generate_create_instance_helper_file(work_path, ops_yaml_str)
|
|
1086
|
-
# get view extra ops
|
|
1087
|
-
extra_ops = get_view_ops(ops_yaml_str)
|
|
1088
|
-
# generate pyboost code
|
|
1089
|
-
gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str, extra_ops)
|
|
1090
|
-
# generate aclnn kernelmod register
|
|
1091
|
-
generate_aclnn_reg_file(work_path, ops_yaml_str)
|
|
42
|
+
for generator in module_generators:
|
|
43
|
+
generator(resource_mgr)
|
|
1092
44
|
|
|
45
|
+
gen_utils.clear_obsolete_auto_gen_files()
|
|
1093
46
|
|
|
1094
47
|
if __name__ == "__main__":
|
|
1095
48
|
try:
|
|
1096
49
|
main()
|
|
1097
|
-
# pylint: disable=broad-except
|
|
1098
|
-
except Exception as e:
|
|
50
|
+
except Exception as e: # pylint: disable=broad-except
|
|
1099
51
|
logging.critical("Auto generate failed, err info: %s", e)
|
|
52
|
+
raise e
|