mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2025 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Generate operator definitions from ops.yaml
|
|
17
|
+
"""
|
|
18
|
+
import copy
|
|
19
|
+
|
|
20
|
+
from common import gen_constants as K
|
|
21
|
+
from resources.resource_list import ResourceType
|
|
22
|
+
|
|
23
|
+
from .ops_def_cc_generator import OpsDefCcGenerator
|
|
24
|
+
from .ops_def_h_generator import OpsDefHGenerator
|
|
25
|
+
from .ops_name_h_generator import OpsNameHGenerator
|
|
26
|
+
from .ops_primitive_h_generator import OpsPrimitiveHGenerator
|
|
27
|
+
from .lite_ops_cpp_generator import LiteOpsCcGenerator, LiteOpsHGenerator
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def call_ops_def_cc_generator(work_path, op_protos):
    """Run ``OpsDefCcGenerator`` over the given operator prototypes.

    Args:
        work_path: Working directory passed straight through to the generator.
        op_protos: Operator prototypes to generate code for.
    """
    OpsDefCcGenerator().generate(work_path, op_protos)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def call_ops_def_h_generator(work_path, op_protos):
    """Run ``OpsDefHGenerator`` over the given operator prototypes.

    Args:
        work_path: Working directory passed straight through to the generator.
        op_protos: Operator prototypes to generate code for.
    """
    OpsDefHGenerator().generate(work_path, op_protos)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def call_ops_primitive_h_generator(work_path, op_protos):
    """Run ``OpsPrimitiveHGenerator`` over ``op_protos`` with the given ``work_path``."""
    OpsPrimitiveHGenerator().generate(work_path, op_protos)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def call_lite_ops_h_generator(work_path, op_protos):
    """Run ``LiteOpsHGenerator`` to write ``gen_lite_ops.h`` for ``op_protos``."""
    LiteOpsHGenerator().generate(work_path, op_protos)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def call_lite_ops_cc_generator(work_path, op_protos):
    """Run ``LiteOpsCcGenerator`` to write ``gen_lite_ops.cc`` for ``op_protos``."""
    LiteOpsCcGenerator().generate(work_path, op_protos)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def call_ops_name_h_generator(work_path, op_protos):
    """Run ``OpsNameHGenerator`` over ``op_protos`` with the given ``work_path``."""
    OpsNameHGenerator().generate(work_path, op_protos)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_tensor_op_protos_with_deprecated(func_protos, op_protos):
    """
    Get op_protos with deprecated op_protos from func_protos.

    Returns a deep copy of ``op_protos`` extended with every op proto referenced by
    ``func_protos`` whose op name contains ``"deprecated"``.

    Before being appended, each such proto has its ``op_class.name`` rewritten to the
    CamelCase form of its op name, e.g. ``deprecated_add`` -> ``DeprecatedAdd``. A trailing
    underscore is kept, e.g. ``deprecated_add_`` -> ``DeprecatedAdd_``.

    Note: the deprecated protos are appended by reference (not copied), so the rename is
    also visible through ``func_protos``.
    """
    merged_protos = copy.deepcopy(op_protos)
    for func_proto_list in func_protos.values():
        for func_proto in func_proto_list:
            proto = func_proto.op_proto
            name = proto.op_name
            if "deprecated" not in name:
                continue
            # split('_') leaves an empty last piece for a trailing '_', which join drops,
            # so it is re-added explicitly.
            camel_name = ''.join(piece.capitalize() for piece in name.split('_'))
            if name.endswith('_'):
                camel_name += '_'
            proto.op_class.name = camel_name
            merged_protos.append(proto)
    return merged_protos
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def generate_ops_def_files(resource_mgr):
    """
    Generate ops c++ file from yaml.

    The OpDef ``.cc``/``.h`` generators receive the op protos extended with the deprecated
    tensor-method protos (used in graph mode). All other generators receive the plain op protos.

    Args:
        resource_mgr: Resource manager providing ``ResourceType.OP_PROTO`` and
            ``ResourceType.TENSOR_METHOD_PROTOS``.
    """
    out_dir = K.WORK_DIR
    plain_protos = resource_mgr.get_resource(ResourceType.OP_PROTO)
    method_protos = resource_mgr.get_resource(ResourceType.TENSOR_METHOD_PROTOS)
    # for generate tensor method deprecated in graph mode
    protos_with_deprecated = get_tensor_op_protos_with_deprecated(method_protos, plain_protos)

    generation_plan = (
        (call_ops_def_cc_generator, protos_with_deprecated),
        (call_ops_def_h_generator, protos_with_deprecated),
        (call_ops_primitive_h_generator, plain_protos),
        (call_lite_ops_h_generator, plain_protos),
        (call_lite_ops_cc_generator, plain_protos),
        (call_ops_name_h_generator, plain_protos),
    )
    for run_generator, protos in generation_plan:
        run_generator(out_dir, protos)
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Generates C++ header and source files for lite operations based on YAML configurations.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
import common.gen_constants as K
|
|
22
|
+
import common.gen_utils as gen_utils
|
|
23
|
+
import common.template as template
|
|
24
|
+
from common.base_generator import BaseGenerator
|
|
25
|
+
from pyboost import pyboost_utils
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Skeleton of the generated gen_lite_ops.h. ``$ops_namespace_body`` receives the per-operator
# class declarations rendered by LiteOpsHGenerator.
LITE_OPS_H = """
#ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
#define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_

#include <vector>
#include "ops/base_operator.h"

namespace mindspore::ops {
$ops_namespace_body

} // namespace mindspore::ops
#endif // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
"""

# One ``#include`` of gen_ops_name_<ch>.h, where ``ch`` is the lower-cased first letter of an
# operator name (as filled in by LiteOpsCcGenerator).
INC_OPS_HEAD = """
#include "$auto_gen_path/gen_ops_name_${ch}.h"
"""

# Skeleton of the generated gen_lite_ops.cc. ``${inc_ops_head_str}`` receives the sorted,
# de-duplicated INC_OPS_HEAD lines. ``$ops_namespace_body`` receives the per-operator accessor
# implementations and registrations rendered by LiteOpsCcGenerator.
LITE_OPS_CC = """
#include "$auto_gen_path/gen_lite_ops.h"
${inc_ops_head_str}
#include "mindapi/helper.h"
#include "ops/primitive_c.h"
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"

namespace mindspore::ops {
$ops_namespace_body

} // namespace mindspore::ops
"""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class LiteOpsHGenerator(BaseGenerator):
    """
    This class is responsible for generating the header file for lite operations.

    For every operator it renders one class declaration. Each primitive-init argument of the
    operator contributes a ``set_<arg>``/``get_<arg>`` accessor pair to that declaration.
    """

    def __init__(self):
        """
        Initializes the generator with the necessary templates for generating C++ header files.
        """
        self.lite_ops_h_template = template.Template(LITE_OPS_H)
        self.lite_ops_class_template = template.op_cc_template
        # Accessor declarations emitted for each primitive-init argument.
        self.arg_prim_init_template = template.Template(
            "\n"
            " void set_${arg_name}(const ${dtype} &${arg_name});\n"
            " ${dtype} get_${arg_name}() const;"
        )

    def generate(self, work_path, op_protos):
        """
        Generates the header file content for lite operations and saves it to the specified path.

        Args:
            work_path (str): The directory where the generated files will be saved.
            op_protos (list): A list of operator prototypes containing information about the operators.

        Returns:
            None
        """
        class_declarations = []
        for proto in op_protos:
            class_name = pyboost_utils.get_op_name(proto.op_name, proto.op_class.name)
            accessor_declarations = ''.join(
                self.arg_prim_init_template.replace(arg_name=arg.arg_name,
                                                    dtype=trans_dtype_for_lite(arg.arg_dtype))
                for arg in proto.op_args
                if arg.is_prim_init
            )
            class_declarations.append(
                self.lite_ops_class_template.replace(op_name=class_name,
                                                     arg_prim_init_list=accessor_declarations))

        header_text = self.lite_ops_h_template.replace(auto_gen_path=K.OP_DEF_AUTO_GENERATE_PATH,
                                                       ops_namespace_body=class_declarations)
        gen_utils.save_file(os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH),
                            "gen_lite_ops.h",
                            template.CC_LICENSE_STR + header_text)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class LiteOpsCcGenerator(BaseGenerator):
    """
    This class is responsible for generating the source file for lite operations.

    For every operator it emits:
      * accessor implementations for each primitive-init argument;
      * a constructor, ``REGISTER_PRIMITIVE_C`` and ``MIND_API_OPERATOR_IMPL`` registration.

    It also emits one ``#include`` of ``gen_ops_name_<first letter>.h`` per distinct first letter
    of the operator names.
    """

    def __init__(self):
        """
        Initializes the generator with the necessary templates for generating C++ source files.
        """
        self.inc_ops_head_template = template.Template(INC_OPS_HEAD)
        # Backward-compatible alias for the previously misspelled attribute name.
        self.inc_ops_head_templat = self.inc_ops_head_template
        self.lite_ops_cc_template = template.Template(LITE_OPS_CC)
        self.op_template = template.op_template
        self.register_primitive_c_template = template.Template(
            "${op_name}::${op_name}():BaseOperator(kName${op_name}) {}\n"
            "REGISTER_PRIMITIVE_C(kName${op_name}, ${op_name});\n"
            "MIND_API_OPERATOR_IMPL(${op_name}, BaseOperator);\n\n")

    def generate(self, work_path, op_protos):
        """
        Generates the source file content for lite operations and saves it to the specified path.

        Args:
            work_path (str): The directory where the generated files will be saved.
            op_protos (list): A list of operation prototypes to generate content for.

        Returns:
            None
        """
        lite_ops_cc_gen_list = []
        inc_ops_heads = set()
        for op_proto in op_protos:
            op_name = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
            accessor_impls = ''.join(
                self.op_template.replace(op_name=op_name, arg_name=op_arg.arg_name,
                                         dtype=trans_dtype_for_lite(op_arg.arg_dtype))
                for op_arg in op_proto.op_args
                if op_arg.is_prim_init
            )
            # Render the registration once (the previous code rendered it twice and discarded
            # the first result).
            registration = self.register_primitive_c_template.replace(op_name=op_name)
            lite_ops_cc_gen_list.append(accessor_impls + registration)
            # NOTE(review): presumably gen_ops_name_<ch>.h declares the kName<Op> constant used
            # above. The set de-duplicates headers shared by operators with the same initial.
            inc_ops_heads.add(self.inc_ops_head_template.replace(auto_gen_path=K.OP_DEF_AUTO_GENERATE_PATH,
                                                                 ch=op_name[0].lower()))
        # Sort so the generated include order is deterministic across runs.
        sorted_inc_ops_heads = sorted(inc_ops_heads)
        lite_ops_cc = self.lite_ops_cc_template.replace(auto_gen_path=K.OP_DEF_AUTO_GENERATE_PATH,
                                                        ops_namespace_body=lite_ops_cc_gen_list,
                                                        inc_ops_head_str=sorted_inc_ops_heads)

        res_str = template.CC_LICENSE_STR + lite_ops_cc
        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        gen_utils.save_file(save_path, "gen_lite_ops.cc", res_str)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# yaml dtype -> C++ type used by the lite operator API. Dtypes not listed here are passed
# through unchanged by trans_dtype_for_lite.
_LITE_DTYPE_MAP = {
    "int": "int64_t",
    "str": "std::string",
    "list[int]": "std::vector<int64_t>",
    "tuple[int]": "std::vector<int64_t>",
    "list[float]": "std::vector<float>",
    "tuple[float]": "std::vector<float>",
    "list[bool]": "std::vector<bool>",
    "tuple[bool]": "std::vector<bool>",
    "list[str]": "std::vector<std::string>",
    "tuple[str]": "std::vector<std::string>",
}


def trans_dtype_for_lite(dtype):
    """
    Translate the data type for lite usage based on the argument information.

    Args:
        dtype (str): The original data type as a string.

    Returns:
        str: The mapped C++ type if ``dtype`` is known, otherwise ``dtype`` itself.
    """
    try:
        return _LITE_DTYPE_MAP[dtype]
    except KeyError:
        return dtype
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Module for generating C++ operator definition files.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import math
|
|
21
|
+
|
|
22
|
+
import common.gen_constants as K
|
|
23
|
+
import common.gen_utils as gen_utils
|
|
24
|
+
|
|
25
|
+
# refactored
|
|
26
|
+
from common.op_proto import OpProto
|
|
27
|
+
import common.template as template
|
|
28
|
+
|
|
29
|
+
from common.base_generator import BaseGenerator
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
CC_OPS_DEF = """
|
|
33
|
+
|
|
34
|
+
#include "$auto_generate_path/gen_ops_def.h"
|
|
35
|
+
#include "ir/signature.h"
|
|
36
|
+
$gen_include
|
|
37
|
+
|
|
38
|
+
namespace mindspore::ops {$gen_cc_code
|
|
39
|
+
} // namespace mindspore::ops
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class OpsDefCcGenerator(BaseGenerator):
    """
    Generates C++ definition files for operators.

    Output files:
      * Regular operators are split into chunks of at most ``max_op_size_in_one_file``
        definitions, written to ``gen_ops_def_a.cc``, ``gen_ops_def_b.cc``, ...
      * Operators whose name contains ``"deprecated"`` all go to ``gen_deprecated_ops_def.cc``.

    An operator's FuncImpl ``#include`` (emitted only for non-``func_op`` operators) is always
    written into the same file as that operator's definition.
    """

    # Maximum number of operator definitions emitted into one gen_ops_def_<letter>.cc file.
    max_op_size_in_one_file = 300

    def __init__(self):
        """
        Initializes templates for generating C++ operator definitions.
        """
        self.include_template = template.Template("""#include "${path}/${operator_name}.h\"\n""")
        self.func_impl_declaration_template = template.Template("${class_name}FuncImpl g${class_name}FuncImpl;")
        self.empty_func_impl_declaration_template = template.Template("static OpFuncImpl g${class_name}FuncImpl;")
        self.func_impl_define_template = template.Template("g${class_name}FuncImpl")
        self.OP_PROTO_TEMPLATE = template.OP_PROTO_TEMPLATE
        self.CC_OPS_DEF_TEMPLATE = template.Template(CC_OPS_DEF)

    def generate(self, work_path, op_protos):
        """
        Generates C++ code for operator definitions and saves it to files.

        Args:
            work_path (str): The directory to save the generated files.
            op_protos (list): A list of operator prototypes.
        """
        # Each entry is an (include_line, op_def_cc) pair. Keeping the pair together matters:
        # the include is empty for func_op operators, so chunking includes and definitions as
        # two independent lists (as before) misaligns them and can drop trailing definitions.
        regular_defs = []
        deprecated_defs = []
        for op_proto in op_protos:
            op_def = self._gen_op_def(op_proto)
            if "deprecated" in op_proto.op_name:
                deprecated_defs.append(op_def)
            else:
                regular_defs.append(op_def)

        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        chunk_size = self.max_op_size_in_one_file
        # NOTE(review): more than 26 chunks would produce non-letter file suffixes.
        for numbering in range(math.ceil(len(regular_defs) / chunk_size)):
            chunk = regular_defs[numbering * chunk_size: (numbering + 1) * chunk_size]
            self._save_cc_file(save_path, f"gen_ops_def_{chr(ord('a') + numbering)}.cc", chunk)

        self._save_cc_file(save_path, "gen_deprecated_ops_def.cc", deprecated_defs)

    def _gen_op_def(self, op_proto):
        """Return ``(include_line, op_def_cc)`` for one operator; the include is '' for func ops."""
        operator_name = op_proto.op_name
        class_name = op_proto.op_class.name
        if op_proto.func_op:
            include_line = ''
            func_impl_declaration_str = self.empty_func_impl_declaration_template.replace(class_name=class_name)
        else:
            include_line = self.include_template.replace(path=K.MS_OPS_FUNC_IMPL_PATH,
                                                         operator_name=operator_name)
            func_impl_declaration_str = self.func_impl_declaration_template.replace(class_name=class_name)
        func_impl_define = self.func_impl_define_template.replace(class_name=class_name)

        # process input
        args_dict, cc_index_str, input_args_str = process_input_args(op_proto)

        # Process outputs.
        return_args_str = get_cc_op_def_return(args_dict, op_proto)

        inputs_args = self.process_args(op_proto.op_args)
        signature_code = generate_cc_op_signature(op_proto.op_args_signature, inputs_args)
        enable_dispatch = "true" if op_proto.op_dispatch and op_proto.op_dispatch.enable else "false"
        is_view = "true" if op_proto.op_view else "false"
        is_graph_view = "true" if op_proto.op_graph_view else "false"
        op_def_cc = self.OP_PROTO_TEMPLATE.replace(class_name=class_name,
                                                   input_args=input_args_str,
                                                   return_args=return_args_str,
                                                   signatures=signature_code,
                                                   indexes=cc_index_str,
                                                   enable_dispatch=enable_dispatch,
                                                   is_view=is_view,
                                                   is_graph_view=is_graph_view,
                                                   func_impl_declaration=func_impl_declaration_str,
                                                   func_impl_define=func_impl_define)
        if op_proto.op_view:
            # str.replace renames every occurrence of the class name (including inside
            # g<Class>FuncImpl), producing a parallel <Class>View definition.
            op_def_cc += op_def_cc.replace(class_name, class_name + "View")
        return include_line, op_def_cc

    def _save_cc_file(self, save_path, file_name, op_defs):
        """Render ``(include_line, op_def_cc)`` pairs into one licensed ``.cc`` file and save it."""
        cc_ops_def = self.CC_OPS_DEF_TEMPLATE.replace(auto_generate_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
                                                      gen_include=''.join(inc for inc, _ in op_defs),
                                                      gen_cc_code=''.join(cc for _, cc in op_defs))
        gen_utils.save_file(save_path, file_name, template.CC_LICENSE_STR + cc_ops_def)

    def process_args(self, op_args):
        """
        Processes operator arguments to extract input names.

        Args:
            op_args (list): A list of operator arguments.

        Returns:
            list: Names of the arguments that are not primitive-init arguments, in order.
        """
        return [arg.arg_name for arg in op_args if not arg.is_prim_init]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def process_input_args(op_proto: OpProto):
    """
    Processes input arguments for C++ code generation.

    Args:
        op_proto (OpProto): The operator prototype.

    Returns:
        tuple: ``(args_dict, cc_index_str, input_args_str)``:
            * ``args_dict`` maps each argument name to its position in ``op_proto.op_args``;
            * ``cc_index_str`` has one ``{"name", index},`` line per argument;
            * ``input_args_str`` has one brace initializer per argument (name, dtype,
              init-arg flag, handler, cast dtypes, optional flag).
    """
    args_dict = {}
    index_entries = []
    input_arg_entries = []
    for i, op_arg in enumerate(op_proto.op_args):
        arg_name = op_arg.arg_name
        args_dict[arg_name] = i
        index_entries.append(f"""{{"{arg_name}", {i}}},\n""")
        cc_dtype_str = gen_utils.convert_dtype_str(op_arg.arg_dtype)

        is_prim_init = 1 if op_arg.is_prim_init else 0
        arg_handler_str = op_arg.arg_handler

        # e.g. 'tuple[int]' -> 'DT_TUPLE_INT'.
        type_cast = op_arg.type_cast
        type_cast_str = "" if type_cast is None else \
            ", ".join('DT_' + cast.replace('[', '_').replace(']', '').upper() for cast in type_cast)

        # default: None is regarded as an optional argument.
        is_optional_str = "true" if op_arg.default == "None" else "false"

        input_arg_entries.append(
            f"""\n {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """
            f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """
            f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},""")
    return args_dict, ''.join(index_entries), ''.join(input_arg_entries)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def get_cc_op_def_return(args_dict, op_proto: OpProto):
    """
    Generates return argument strings for C++ operator definition.

    Args:
        args_dict (dict): A dictionary mapping argument names to indexes.
        op_proto (OpProto): The operator prototype.

    Returns:
        str: One brace initializer per return (name, dtype, inplace input index). The inplace
            index is -1 when the return is not inplace.

    Raises:
        ValueError: If a return is declared inplace on a name that is not an input argument.
            Previously this silently emitted the literal ``None`` into the generated C++.
    """
    return_entries = []
    for return_item in op_proto.op_returns:
        return_name = return_item.arg_name
        ref_name = return_item.inplace
        if not ref_name:
            ref_index = -1
        elif ref_name in args_dict:
            ref_index = args_dict[ref_name]
        else:
            raise ValueError(f"Return '{return_name}' of op '{op_proto.op_name}' is declared inplace on "
                             f"'{ref_name}', which is not an input argument of the op.")
        # e.g. 'tuple[int]' -> 'DT_TUPLE_INT'.
        cc_return_type_str = 'DT_' + return_item.arg_dtype.replace('[', '_').replace(']', '').upper()
        return_entries.append(f'{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},\n'
                              f'/*.inplace_input_index_=*/{ref_index}}},\n')
    return ''.join(return_entries)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def generate_cc_op_signature(args_signature, args_name):
    """
    Generates C++ signature code for operator arguments.

    Args:
        args_signature (dict): A dictionary containing argument signatures, or None.
        args_name (list): A list of argument names.

    Returns:
        str: One ``Signature(...)`` initializer per argument name, or '' when
            ``args_signature`` is None.
    """
    if args_signature is None:
        return ''

    read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
    same_dtype_groups, _ = gen_utils.get_same_dtype_groups(args_signature, args_name)

    def _render(name):
        rw_label = signature_get_rw_label_cc(name, write_list, read_list, ref_list)
        dtype_label = signature_get_enum_dtype_cc(same_dtype_groups.get(name))
        return (f'Signature("{name}", {rw_label}, '
                f' SignatureEnumKind::kKindPositionalKeyword, nullptr, {dtype_label}),\n ')

    return ''.join(_render(name) for name in args_name)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
    """
    Determines the read-write label for a C++ signature.

    The lists are checked in priority order write, read, ref; the first one containing
    ``rw_op_name`` decides the label. If none contains it, ``kRWDefault`` is used.

    Args:
        rw_op_name (str): The name of the read-write operation.
        write_list (list): A list of write operations.
        read_list (list): A list of read operations.
        ref_list (list): A list of reference operations.

    Returns:
        str: The read-write label code, e.g. ``SignatureEnumRW::kRWRead``.
    """
    candidates = (('kRWWrite', write_list), ('kRWRead', read_list), ('kRWRef', ref_list))
    matched = next((label for label, names in candidates if rw_op_name in names), 'kRWDefault')
    return f'SignatureEnumRW::{matched}'
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def signature_get_enum_dtype_cc(index):
    """
    Generates C++ enum data type code for a signature.

    Args:
        index (int): The same-dtype group index (0-9), or any other value for "no group".

    Returns:
        str: ``SignatureEnumDType::kDType`` for 0, ``SignatureEnumDType::kDType<N>`` for 1-9,
            otherwise ``SignatureEnumDType::kDTypeEmptyDefaultValue``.
    """
    group_labels = {0: 'kDType'}
    group_labels.update({n: f'kDType{n}' for n in range(1, 10)})
    return 'SignatureEnumDType::' + group_labels.get(index, 'kDTypeEmptyDefaultValue')
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the OpHeaderFileGenerator class for generating header files for operator definitions.
|
|
17
|
+
|
|
18
|
+
The generator creates C++ header files that declare external operator definitions based on operator prototypes
|
|
19
|
+
and any additional operators provided. This is useful for managing operator interfaces in a consistent way.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
import common.template as template
|
|
25
|
+
from common.template import Template
|
|
26
|
+
from common.gen_utils import save_file
|
|
27
|
+
import common.gen_constants as K
|
|
28
|
+
from common.base_generator import BaseGenerator
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OpsDefHGenerator(BaseGenerator):
    """
    Generates the ``gen_ops_def.h`` header for operator definitions.

    For each operator prototype an ``OPS_API extern OpDef g<ClassName>;`` declaration is emitted.
    Operators marked as views additionally get a ``g<ClassName>View`` declaration. All view
    declarations are listed after the regular ones.
    """

    def __init__(self):
        """Initializes the generator and its templates."""
        super().__init__()
        self.extern_template = Template("OPS_API extern OpDef g${op_name};\n")
        self.GEN_OPS_DEF_HEADER_TEMPLATE = template.GEN_OPS_DEF_HEADER_TEMPLATE

    def generate(self, work_path, op_protos):
        """
        Generates the operator definition header file and saves it to the specified path.

        Args:
            work_path (str): The directory path where the generated header file will be saved.
            op_protos (list): A list of operator prototypes containing information about the operators.

        Returns:
            None
        """
        class_names = [op_proto.op_class.name for op_proto in op_protos]
        view_class_names = [op_proto.op_class.name + "View" for op_proto in op_protos if op_proto.op_view]
        extern_str = ''.join(self.extern_template.replace(op_name=name)
                             for name in class_names + view_class_names)

        ops_header_file = self.GEN_OPS_DEF_HEADER_TEMPLATE.replace(extern_variable=extern_str)

        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        save_file(save_path, "gen_ops_def.h", ops_header_file)
|