mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Module for generating C++ operator definition files.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import math
|
|
21
|
+
|
|
22
|
+
import gen_constants as K
|
|
23
|
+
import gen_utils
|
|
24
|
+
|
|
25
|
+
# refactored
|
|
26
|
+
from op_proto import OpProto
|
|
27
|
+
import template
|
|
28
|
+
|
|
29
|
+
from base_generator import BaseGenerator
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OpsDefCcGenerator(BaseGenerator):
    """
    Emits the C++ operator-definition source files (``gen_ops_def_*.cc``) from operator prototypes.

    Operators whose name does not contain "deprecated" are split across files holding at most 300
    definitions each; operators whose name contains "deprecated" are collected into a single
    ``gen_deprecated_ops_def.cc``.
    """

    def __init__(self):
        """
        Builds the string templates used while rendering each operator definition.
        """
        self.include_template = template.Template('#include "${path}/${operator_name}.h"\n')
        self.func_impl_declaration_template = template.Template("${class_name}FuncImpl g${class_name}FuncImpl;")
        self.empty_func_impl_declaration_template = template.Template("static OpFuncImpl g${class_name}FuncImpl;")
        self.func_impl_define_template = template.Template("g${class_name}FuncImpl")
        self.OP_PROTO_TEMPLATE = template.OP_PROTO_TEMPLATE
        self.CC_OPS_DEF_TEMPLATE = template.Template(K.CC_OPS_DEF)

    def generate(self, work_path, op_protos):
        """
        Renders the C++ definition of every operator and writes the resulting ``.cc`` files.

        Args:
            work_path (str): Root directory; files are written under ``K.MS_OP_DEF_AUTO_GENERATE_PATH``.
            op_protos (list): Operator prototypes to render.
        """
        regular_defs = []
        regular_includes = []
        deprecated_defs = []

        for proto in op_protos:
            op_name = proto.op_name
            cls_name = proto.op_class.name
            is_deprecated = "deprecated" in op_name

            if is_deprecated:
                # No FuncImpl header is included for these ops; a plain static OpFuncImpl is declared instead.
                impl_decl = self.empty_func_impl_declaration_template.replace(class_name=cls_name)
            else:
                regular_includes.append(
                    self.include_template.replace(path=K.MS_OPS_FUNC_IMPL_PATH, operator_name=op_name))
                impl_decl = self.func_impl_declaration_template.replace(class_name=cls_name)
            impl_define = self.func_impl_define_template.replace(class_name=cls_name)

            # Inputs, outputs and signatures of the operator.
            args_dict, index_code, input_args_code = process_input_args(proto)
            return_args_code = get_cc_op_def_return(args_dict, proto)
            signature_code = generate_cc_op_signature(proto.op_args_signature, self.process_args(proto.op_args))

            dispatch = proto.op_dispatch
            op_def_code = self.OP_PROTO_TEMPLATE.replace(
                class_name=cls_name,
                input_args=input_args_code,
                return_args=return_args_code,
                signatures=signature_code,
                indexes=index_code,
                enable_dispatch="true" if dispatch and dispatch.enable else "false",
                is_view="true" if proto.op_view else "false",
                is_graph_view="true" if proto.op_graph_view else "false",
                func_impl_declaration=impl_decl,
                func_impl_define=impl_define)
            if proto.op_view:
                # View ops also get a copy of the definition with every occurrence of the class name
                # suffixed with "View".
                op_def_code += op_def_code.replace(cls_name, cls_name + "View")

            if is_deprecated:
                deprecated_defs.append(op_def_code)
            else:
                regular_defs.append(op_def_code)

        chunk = 300
        out_dir = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        file_count = math.ceil(len(regular_includes) / chunk)
        for file_idx in range(file_count):
            lo, hi = file_idx * chunk, (file_idx + 1) * chunk
            body = self.CC_OPS_DEF_TEMPLATE.replace(auto_generate_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
                                                    gen_include=''.join(regular_includes[lo:hi]),
                                                    gen_cc_code=''.join(regular_defs[lo:hi]))
            # Output files are suffixed a, b, c, ... in chunk order.
            gen_utils.save_file(out_dir, f"gen_ops_def_{chr(ord('a') + file_idx)}.cc",
                                template.CC_LICENSE_STR + body)

        deprecated_body = self.CC_OPS_DEF_TEMPLATE.replace(auto_generate_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
                                                           gen_include='',
                                                           gen_cc_code=''.join(deprecated_defs))
        gen_utils.save_file(out_dir, "gen_deprecated_ops_def.cc", template.CC_LICENSE_STR + deprecated_body)

    def process_args(self, op_args):
        """
        Collects the names of arguments that are not primitive-init arguments.

        Args:
            op_args (list): Operator arguments, each exposing ``arg_name`` and ``is_prim_init``.

        Returns:
            list: Names of the arguments whose ``is_prim_init`` is falsy, in their original order.
        """
        return [arg.arg_name for arg in op_args if not arg.is_prim_init]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def process_input_args(op_proto: OpProto):
    """
    Builds the C++ fragments that describe an operator's input arguments.

    Args:
        op_proto (OpProto): The operator prototype whose ``op_args`` are rendered.

    Returns:
        tuple: ``(args_dict, cc_index_str, input_args_str)``, where ``args_dict`` maps each argument
        name to its position, ``cc_index_str`` holds one ``{"name", index},`` line per argument, and
        ``input_args_str`` holds one C++ input-argument initializer per argument.
    """
    args_dict = {}
    index_lines = []
    arg_entries = []
    for position, arg in enumerate(op_proto.op_args):
        name = arg.arg_name
        args_dict[name] = position
        index_lines.append(f"""{{"{name}", {position}}},\n""")

        cc_dtype = gen_utils.convert_dtype_str(arg.arg_dtype)
        init_flag = 1 if arg.is_prim_init else 0
        handler = arg.arg_handler

        cast_list = arg.type_cast
        if cast_list is None:
            cast_str = ""
        else:
            # e.g. "tuple[int]" -> "DT_TUPLE_INT"
            cast_str = ", ".join('DT_' + dt_name.replace('[', '_').replace(']', '').upper()
                                 for dt_name in cast_list)

        # A default of the string "None" marks the argument as optional.
        optional = "true" if arg.default == "None" else "false"

        arg_entries.append(
            f'\n {{/*.arg_name_=*/"{name}", /*.arg_dtype_=*/{cc_dtype}, '
            f'/*.as_init_arg_=*/{init_flag}, /*.arg_handler_=*/"{handler}", '
            f'/*.cast_dtype_ =*/{{{cast_str}}}, /*.is_optional_=*/{optional}}},')
    return args_dict, ''.join(index_lines), ''.join(arg_entries)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def get_cc_op_def_return(args_dict, op_proto: "OpProto"):
    """
    Generates return argument strings for C++ operator definition.

    Args:
        args_dict (dict): A dictionary mapping input argument names to their indexes.
        op_proto (OpProto): The operator prototype; its ``op_returns`` are rendered.

    Returns:
        str: One C++ return-argument initializer per return, concatenated.

    Raises:
        ValueError: If a return's ``inplace`` names an argument that is not in ``args_dict``.
    """
    return_entries = []
    for ret in op_proto.op_returns:
        ref_name = ret.inplace
        if not ref_name:
            ref_index = -1
        elif ref_name in args_dict:
            ref_index = args_dict[ref_name]
        else:
            # Without this check the literal "None" would be written into the generated C++ as the
            # inplace index, which only surfaces later as a C++ compile error.
            raise ValueError(f"Return '{ret.arg_name}' declares inplace='{ref_name}', "
                             f"which is not an input argument of the operator.")
        # e.g. "tuple[int]" -> "DT_TUPLE_INT"
        cc_return_type = 'DT_' + ret.arg_dtype.replace('[', '_').replace(']', '').upper()
        return_entries.append(f'{{/*.arg_name_=*/"{ret.arg_name}", /*.arg_dtype_=*/{cc_return_type},\n'
                              f'/*.inplace_input_index_=*/{ref_index}}},\n')
    return ''.join(return_entries)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def generate_cc_op_signature(args_signature, args_name):
    """
    Renders the C++ ``Signature(...)`` entries for an operator's arguments.

    Args:
        args_signature (dict): The operator's argument-signature description, or None.
        args_name (list): Argument names to render, in order.

    Returns:
        str: One ``Signature(...)`` entry per name in ``args_name``, or '' when ``args_signature`` is None.
    """
    if args_signature is None:
        return ''

    read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
    dtype_groups, _ = gen_utils.get_same_dtype_groups(args_signature, args_name)

    pieces = []
    for name in args_name:
        rw_label = signature_get_rw_label_cc(name, write_list, read_list, ref_list)
        dtype_label = signature_get_enum_dtype_cc(dtype_groups.get(name))
        pieces.append(f'Signature("{name}", {rw_label},  SignatureEnumKind::kKindPositionalKeyword, '
                      f'nullptr, {dtype_label}),\n ')
    return ''.join(pieces)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
    """
    Picks the C++ ``SignatureEnumRW`` label for an argument.

    Args:
        rw_op_name (str): The argument name to classify.
        write_list (list): Names treated as write arguments.
        read_list (list): Names treated as read arguments.
        ref_list (list): Names treated as ref arguments.

    Returns:
        str: ``SignatureEnumRW::<label>``. Write is checked first, then read, then ref; if the name is
        in none of them the label is ``kRWDefault``.
    """
    candidates = (('kRWWrite', write_list), ('kRWRead', read_list), ('kRWRef', ref_list))
    label = next((tag for tag, members in candidates if rw_op_name in members), 'kRWDefault')
    return 'SignatureEnumRW::' + label
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def signature_get_enum_dtype_cc(index):
    """
    Maps a same-dtype group index to its C++ ``SignatureEnumDType`` enumerator.

    Args:
        index (int): Group index; 0 maps to ``kDType`` and 1..9 map to ``kDType1``..``kDType9``.

    Returns:
        str: ``SignatureEnumDType::<name>``. Any other value, including None, maps to
        ``kDTypeEmptyDefaultValue``.
    """
    known = {0: 'kDType', **{n: f'kDType{n}' for n in range(1, 10)}}
    return 'SignatureEnumDType::' + known.get(index, 'kDTypeEmptyDefaultValue')
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
This module defines the OpsDefHGenerator class for generating header files for operator definitions.
|
|
17
|
+
|
|
18
|
+
The generator creates C++ header files that declare external operator definitions based on operator prototypes
|
|
19
|
+
and an additional `<Name>View` declaration for each operator marked as a view. This is useful for managing operator interfaces in a consistent way.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
import template
|
|
25
|
+
from template import Template
|
|
26
|
+
from gen_utils import save_file
|
|
27
|
+
import gen_constants as K
|
|
28
|
+
from base_generator import BaseGenerator
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OpsDefHGenerator(BaseGenerator):
    """
    Generates the C++ header that declares the operator definitions.

    For every operator prototype, an ``OPS_API extern OpDef g<Name>;`` line is
    emitted. Prototypes whose ``op_view`` attribute is truthy also get a
    declaration for ``g<Name>View``. These view declarations are placed after
    all of the regular declarations.
    """

    def __init__(self):
        """Initializes the OpsDefHGenerator and its templates."""
        super().__init__()
        self.extern_template = Template("OPS_API extern OpDef g${op_name};\n")
        self.GEN_OPS_DEF_HEADER_TEMPLATE = template.GEN_OPS_DEF_HEADER_TEMPLATE

    def generate(self, work_path, op_protos):
        """
        Generates ``gen_ops_def.h`` and saves it to the auto-generate directory.

        One extern declaration is built per operator prototype. A second
        ``<Name>View`` declaration is built for each prototype with ``op_view``
        set. The declarations are substituted into
        ``GEN_OPS_DEF_HEADER_TEMPLATE`` as ``extern_variable``.

        Args:
            work_path (str): Root directory. The header is written to
                ``os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)``.
            op_protos (list): Operator prototypes. Each must expose
                ``op_class.name`` and ``op_view``.

        Returns:
            None
        """
        extern_decls = []
        view_names = []
        # Iterate op_protos only once, so any iterable (not just a list) works.
        for op_proto in op_protos:
            class_name = op_proto.op_class.name
            extern_decls.append(self.extern_template.replace(op_name=class_name))
            if op_proto.op_view:
                view_names.append(class_name + "View")
        # View variants are declared after every base operator declaration.
        extern_decls.extend(self.extern_template.replace(op_name=name) for name in view_names)
        extern_str = ''.join(extern_decls)

        ops_header_file = self.GEN_OPS_DEF_HEADER_TEMPLATE.replace(extern_variable=extern_str)

        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        file_name = "gen_ops_def.h"
        save_file(save_path, file_name, ops_header_file)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Module for generating C++ header files with operator name definitions.
|
|
17
|
+
|
|
18
|
+
This module defines the `OpsNameHGenerator` class, which produces C++ code to define
|
|
19
|
+
constants for operator names based on given prototypes.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
import gen_constants as K
|
|
25
|
+
import gen_utils
|
|
26
|
+
import pyboost_utils
|
|
27
|
+
|
|
28
|
+
# refactored
|
|
29
|
+
import template
|
|
30
|
+
|
|
31
|
+
from base_generator import BaseGenerator
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OpsNameHGenerator(BaseGenerator):
    """
    Class for generating C++ header files containing operator name constants.

    Each operator prototype yields one ``constexpr auto kName<Op> = "<Op>";``
    line. The lines are substituted into the ``K.OP_NAME_OP_DEF`` template,
    and the result is prefixed with the C++ license header.
    """

    def __init__(self):
        """
        Initializes the OpsNameHGenerator instance and its templates.
        """
        # Call the base initializer, consistent with the other generators.
        super().__init__()
        self.op_name_op_def_template = template.Template(K.OP_NAME_OP_DEF)
        self.op_def_body_template = template.Template("""constexpr auto kName${k_name_op} = "${k_name_op}";\n""")

    def generate(self, work_path, op_protos):
        """
        Generates C++ code for operator names and saves it to ``gen_ops_name.h``.

        Args:
            work_path (str): Root directory. The header is written to
                ``os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)``.
            op_protos (list): Operator prototypes. Each must expose
                ``op_name`` and ``op_class.name``.

        Returns:
            None
        """
        op_name_gen_list = []
        for op_proto in op_protos:
            # The constant suffix is derived by pyboost_utils.get_op_name from the
            # prototype's op_name and class name.
            k_name_op = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
            op_name_gen_list.append(self.op_def_body_template.replace(k_name_op=k_name_op))

        # NOTE(review): a list is passed as the substitution value; this relies on
        # template.Template.replace handling list values — confirm in template.py.
        op_name_code = self.op_name_op_def_template.replace(ops_namespace_body=op_name_gen_list)

        op_name_code = template.CC_LICENSE_STR + op_name_code

        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        file_name = "gen_ops_name.h"
        gen_utils.save_file(save_path, file_name, op_name_code)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Module for generating C++ header files for operator primitives.
|
|
17
|
+
|
|
18
|
+
This module defines the `OpsPrimitiveHGenerator` class, which creates C++ header files
|
|
19
|
+
containing definitions for operator primitives based on provided operator prototypes.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
import gen_constants as K
|
|
25
|
+
import gen_utils
|
|
26
|
+
import pyboost_utils
|
|
27
|
+
|
|
28
|
+
# refactored
|
|
29
|
+
import template
|
|
30
|
+
|
|
31
|
+
from base_generator import BaseGenerator
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OpsPrimitiveHGenerator(BaseGenerator):
    """
    This class generates the header file for operator primitives.

    Each operator prototype yields one ``GVAR_DEF(PrimitivePtr, kPrim<Op>, ...)``
    line. The lines are substituted into the ``K.OP_PRIM_OP_DEF`` template, and
    the result is prefixed with the C++ license header.
    """

    def __init__(self):
        """
        Initializes the generator with templates for operator primitive definitions.
        """
        # Call the base initializer, consistent with the other generators.
        super().__init__()
        self.op_prim_op_def_template = template.Template(K.OP_PRIM_OP_DEF)
        # Default primitive construction: name only.
        self.op_def_template = template.Template(
            "GVAR_DEF(PrimitivePtr, kPrim${k_name_op}, std::make_shared<Primitive>(ops::kName${k_name_op}))\n")
        # Used for operators whose signature has rw_write set. The Primitive is
        # constructed with extra arguments.
        self.op_def_rw_template = template.Template(
            "GVAR_DEF(PrimitivePtr, kPrim${k_name_op}, std::make_shared<Primitive>(ops::kName${k_name_op}, "
            "true, kPrimTypeBuiltIn, true))\n")

    def generate(self, work_path, op_protos):
        """
        Generates the header file content for operator primitives and saves it.

        Prototypes with a truthy ``op_args_signature.rw_write`` use the
        read-write template. All others use the default template. The output
        is written to ``gen_ops_primitive.h``.

        Args:
            work_path (str): Root directory. The header is written to
                ``os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)``.
            op_protos (list): Operator prototypes. Each must expose ``op_name``,
                ``op_class.name`` and ``op_args_signature``, which may be falsy.

        Returns:
            None
        """
        ops_prim_gen_list = []
        for op_proto in op_protos:
            k_name_op = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
            signature = op_proto.op_args_signature
            if signature and signature.rw_write:
                prim_template = self.op_def_rw_template
            else:
                prim_template = self.op_def_template
            ops_prim_gen_list.append(prim_template.replace(k_name_op=k_name_op))

        op_prim_op_def = self.op_prim_op_def_template.replace(auto_gen_path=K.MS_OP_DEF_AUTO_GENERATE_PATH,
                                                              ops_prim_gen=ops_prim_gen_list)

        res_str = template.CC_LICENSE_STR + op_prim_op_def

        save_path = os.path.join(work_path, K.MS_OP_DEF_AUTO_GENERATE_PATH)
        file_name = "gen_ops_primitive.h"
        gen_utils.save_file(save_path, file_name, res_str)
|