mindspore-2.4.1-cp310-cp310-win_amd64.whl → mindspore-2.5.0-cp310-cp310-win_amd64.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
This release has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +368 -242
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py (new file)
@@ -0,0 +1,110 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Generates C++ registration code for ACL NN kernels based on operator prototypes.
+"""
+
+import os
+import logging
+import re
+
+import gen_constants as K
+import gen_utils
+import pyboost_utils
+
+import template
+
+from base_generator import BaseGenerator
+from gen_aclnn_implement import gen_aclnn_kernel
+
+
+class AclnnKernelRegisterAutoCcGenerator(BaseGenerator):
+    """Generates ACL NN kernel registration code for Ascend devices."""
+
+    def __init__(self):
+        self.aclnn_reg_code_template = template.Template(K.ACLNN_REG_CODE)
+
+    def generate(self, work_path, op_protos):
+        """
+        Generates registration code for ACL NN kernels and saves it to a file.
+
+        Args:
+            work_path (str): The directory to save the generated file.
+            op_protos (list): List of operator prototypes.
+
+        Returns:
+            None
+        """
+        aclnn_reg_code = []
+        for op_proto in op_protos:
+            if not op_proto.op_dispatch or not op_proto.op_dispatch.enable:
+                continue
+            if op_proto.op_dispatch.ascend != 'default': # KernelMod is provided by yaml, don't auto generate it.
+                continue
+            if check_op_registed(op_proto.op_name):
+                logging.warning("Kernel {%s} is already registered.", op_proto.op_name)
+                continue
+            _, _, none_tensor_exist = pyboost_utils.get_dtypes(op_proto)
+            if none_tensor_exist:
+                # gen operator aclnn kernel c++ files
+                gen_aclnn_kernel(op_proto, auto=True)
+                continue
+
+            class_name = op_proto.op_class.name
+            inputs_outputs_num = len(op_proto.op_args) + len(op_proto.op_returns)
+            aclnn_name = pyboost_utils.AclnnUtils.get_aclnn_interface(class_name)
+            aclnn_reg_code.append(
+                f"MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});\n")
+
+        reg_code = self.aclnn_reg_code_template.replace(ops_gen_kernel_path=K.MS_OPS_KERNEL_PATH,
+                                                        aclnn_reg_code=aclnn_reg_code)
+        res_str = template.CC_LICENSE_STR + reg_code
+
+        save_path = os.path.join(work_path, f"{K.MS_OPS_KERNEL_PATH}/ascend/opapi/auto_generate")
+        file_name = "aclnn_kernel_register_auto.cc"
+        gen_utils.save_file(save_path, file_name, res_str)
+
+
+def get_registed_ops(file_path=f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/'):
+    '''get registered ops by search files'''
+    # default search in 'ops/kernel/ascend/opapi/'
+    current_path = os.path.dirname(os.path.realpath(__file__))
+    work_path = os.path.join(current_path, '../../../../')
+    search_path = os.path.join(work_path, file_path)
+    ret = []
+    try:
+        for root_path, _, files in os.walk(search_path):
+            for file_name in files:
+                if file_name == 'aclnn_kernel_register_auto.cc':
+                    continue
+                with open(os.path.join(root_path, file_name), "r") as f:
+                    file_context = f.read()
+                    search_re = re.search(r"(?<=KERNEL_FACTORY_REG\()\w+(?=,)", file_context)
+                    if search_re:
+                        ret.append(search_re.group())
+    except OSError:
+        logging.warning("Something wrong in check op registered.")
+        return ret
+    return ret
+
+
+registed_ops = get_registed_ops()
+manual_registed_ops = get_registed_ops(f'{K.MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn/')
+
+
+def check_op_registed(op_name, manual=False):
+    '''if op already registered return true'''
+    class_name = ''.join(word.capitalize() for word in op_name.split('_'))
+    return (class_name in manual_registed_ops) if manual else (class_name in registed_ops)
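The name handling in this generator is plain string work. The following minimal sketch (not part of the diff) shows how a registration line is assembled, assuming a made-up operator name, aclnn interface name, and argument count:

def to_class_name(op_name: str) -> str:
    # Same conversion that check_op_registed() applies: snake_case -> CamelCase.
    return ''.join(word.capitalize() for word in op_name.split('_'))

op_name = "batch_norm_grad"            # hypothetical operator name
class_name = to_class_name(op_name)    # -> "BatchNormGrad"
aclnn_name = "aclnnBatchNormBackward"  # assumed value; really produced by AclnnUtils.get_aclnn_interface
inputs_outputs_num = 8                 # assumed; really len(op_args) + len(op_returns)

print(f"MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});")
# MS_ACLNN_COMMON_KERNEL_FACTORY_REG(BatchNormGrad, aclnnBatchNormBackward, 8);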
mindspore/ops_generate/add_tensor_docs_generator.py (new file)
@@ -0,0 +1,54 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Generates mindspore/common/_tensor_docs.py that attaches docs to tensor func APIs when import mindspore
+"""
+
+import os
+import gen_constants as K
+from gen_utils import save_file
+import template
+from template import Template
+from base_generator import BaseGenerator
+
+
+class AddTensorDocsGenerator(BaseGenerator):
+    """
+    This class is responsible for generating a helper file that enable users to view the docstrings of Tensor func APIs.
+    """
+
+    def __init__(self):
+        self.ADD_TENSOR_DOCS_TEMPLATE = template.ADD_TENSOR_DOCS_TEMPLATE
+        self.attach_single_docstr_template = Template('attach_docstr("${api_name}", r"""${docstr}""")')
+
+    def generate(self, work_path, tensor_docs_data):
+        """
+        Generates the content for the helper file and saves it to the specified path.
+
+        Args:
+            work_path (str): The directory where the generated file will be saved.
+            tensor_docs_data (dict): A dict mapping from Tensor func API names to their docstrings.
+
+        Returns:
+            None
+        """
+        add_doc_statements = []
+        for api_name, tensor_doc in tensor_docs_data.items():
+            single_add_doc_statement = self.attach_single_docstr_template.replace(api_name=api_name,
+                                                                                  docstr=tensor_doc['description'])
+            single_add_doc_statement += template.NEW_LINE
+            add_doc_statements.append(single_add_doc_statement)
+        _tensor_docs_py_str = self.ADD_TENSOR_DOCS_TEMPLATE.replace(add_doc_statements=add_doc_statements)
+        save_file(os.path.join(work_path, K.ADD_TENSOR_DOCS_PY_PATH), "_tensor_docs.py", _tensor_docs_py_str)
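For orientation, each statement produced by attach_single_docstr_template looks like the sketch below. It uses the standard-library string.Template as a stand-in for MindSpore's own template.Template (whose replace() method is not shown in this diff); the API name and docstring are invented:

from string import Template

attach_single_docstr_template = Template('attach_docstr("${api_name}", r"""${docstr}""")')

# Hypothetical entry of tensor_docs_data: {api_name: {"description": ...}}
statement = attach_single_docstr_template.substitute(
    api_name="abs",
    docstr="Return the absolute value of each element of the tensor.",
)
print(statement)
# attach_docstr("abs", r"""Return the absolute value of each element of the tensor.""")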
mindspore/ops_generate/arg_handler.py
@@ -116,67 +116,6 @@ def to_2d_paddings(op_name, arg_name, pad):
     raise ValueError(arg_invalid_info(op_name, arg_name, pad))
 
 
-def to_paddings(op_name, arg_name, pad):
-    """
-    convert paddings: int -> tuple[int*4].
-    """
-    if isinstance(pad, int):
-        return (pad,) * 4
-    if isinstance(pad, (tuple, list)):
-        return pad
-    raise ValueError(arg_invalid_info(op_name, arg_name, pad))
-
-
-def to_3d_kernel_size(op_name, arg_name, kernel_size):
-    """
-    convert 3d kernel_size: int/tuple[int*6] -> tuple[int*3].
-    """
-    if isinstance(kernel_size, int):
-        return (kernel_size, kernel_size, kernel_size)
-    if isinstance(kernel_size, (tuple, list)):
-        if len(kernel_size) == 5:
-            return (kernel_size[2], kernel_size[3], kernel_size[4])
-        return kernel_size
-    raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
-
-
-def to_3d_strides(op_name, arg_name, stride):
-    """
-    convert 3d stride: int/tuple[int*6] -> tuple[int*3].
-    """
-    if isinstance(stride, int):
-        return (stride, stride, stride)
-    if isinstance(stride, (tuple, list)):
-        if len(stride) == 5:
-            return (stride[2], stride[3], stride[4])
-        return stride
-    raise ValueError(arg_invalid_info(op_name, arg_name, stride))
-
-
-def to_3d_dilations(op_name, arg_name, dilation):
-    """
-    convert 3d dilation: int/tuple[int*6] -> tuple[int*3].
-    """
-    if isinstance(dilation, int):
-        return (dilation, dilation, dilation)
-    if isinstance(dilation, (tuple, list)):
-        if len(dilation) == 5:
-            return (dilation[2], dilation[3], dilation[4])
-        return dilation
-    raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
-
-
-def to_3d_paddings(op_name, arg_name, pad):
-    """
-    convert 3d paddings: int -> tuple[int*6].
-    """
-    if isinstance(pad, int):
-        return (pad,) * 6
-    if isinstance(pad, (tuple, list)):
-        return pad
-    raise ValueError(arg_invalid_info(op_name, arg_name, pad))
-
-
 def generator_handler(op_name, arg_name, inputs):
     """
     convert constant value in tuple to tensor
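All of the removed helpers follow one normalization pattern: an int is broadcast to a fixed-size tuple, and a 5-element tuple or list is reduced to its last three entries. A condensed sketch of that behavior (illustration only, not part of the diff):

# Condensed restatement of the deleted to_3d_kernel_size / to_3d_strides / to_3d_dilations
# behavior; "value" stands for kernel_size, stride or dilation.
def to_3d(value):
    if isinstance(value, int):
        return (value, value, value)
    if isinstance(value, (tuple, list)):
        if len(value) == 5:
            # a 5-element value keeps indices 2..4 (the trailing spatial dimensions)
            return (value[2], value[3], value[4])
        return value
    raise ValueError(f"unsupported value: {value!r}")

print(to_3d(2))                # (2, 2, 2)
print(to_3d((1, 1, 3, 3, 3)))  # (3, 3, 3)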
mindspore/ops_generate/auto_grad_impl_cc_generator.py (new file)
@@ -0,0 +1,135 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This module provides a generator class for creating C++ implementation files for AutoGrad functionality.
+"""
+
+import os
+
+import template
+from template import Template
+import gen_constants as K
+from gen_utils import save_file
+from base_generator import BaseGenerator
+from pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output
+
+
+class AutoGradImplGenerator(BaseGenerator):
+    """
+    Generates C++ implementation files for the AutoGrad functionality based on operator prototypes.
+    """
+
+    def __init__(self):
+        """
+        Initialize the AutoGrad implementation generator with templates for code generation.
+        """
+        self.AUTO_GRAD_IMPL_CC_TEMPLATE = template.AUTO_GRAD_IMPL_CC_TEMPLATE
+        self.DO_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_GRAD_FUNCTION_BODY_TEMPLATE
+        self.auto_grad_reg_template = Template("const_cast<kernel::pyboost::${class_name}GradFunc&>(" + \
+                                               "kernel::pyboost::AutoGradFactory::Get()." + \
+                                               "ops_auto_grad_registers().${class_name}GradFuncObj) = " + \
+                                               "kernel::pyboost::${class_name}GradFunc(DoGrad${class_name});")
+        self.do_grad_op_args_with_type = Template(
+            "const kernel::pyboost::OpPtr &op, ${input_args_with_type}"
+        )
+
+    def generate(self, work_path, op_protos):
+        """
+        Generate the AutoGrad implementation file.
+
+        Args:
+            work_path (str): The directory where the generated file should be saved.
+            op_protos (list): A list of operator prototypes used to generate the implementation.
+        """
+        auto_grad_reg_list = []
+        do_grad_op_list = []
+        for op_proto in op_protos:
+            if op_proto.op_dispatch is None or op_proto.op_dispatch.is_comm_op:
+                continue
+            auto_grad_reg_list.append(self.auto_grad_reg_template.replace(class_name=op_proto.op_class.name))
+            do_grad_op_list.append(self._get_single_do_grad_op(op_proto))
+        pyboost_func_h_str = self.AUTO_GRAD_IMPL_CC_TEMPLATE.replace(do_grad_op=do_grad_op_list,
+                                                                     auto_grad_reg=auto_grad_reg_list)
+        save_path = os.path.join(work_path, K.PYBOOST_AUTO_GRAD_FUNC_GEN_PATH)
+        file_name = "auto_grad_impl.cc"
+        save_file(save_path, file_name, pyboost_func_h_str)
+
+    def _get_single_do_grad_op(self, op_proto):
+        """
+        Generate the DoGrad function for a single operator prototype.
+
+        Args:
+            op_proto: The operator prototype for which the DoGrad function is generated.
+
+        Returns:
+            str: The generated DoGrad function string.
+        """
+        input_args_str = self._get_input_args(op_proto, False, False)
+        input_args_with_optional_str = self._get_input_args(op_proto, False, True)
+        input_args_with_type_str = self._get_input_args(op_proto, True, False)
+        multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
+        view_arg_str = self._get_view_str(op_proto.op_view, input_args_str)
+        grad_args_with_type_str = self.do_grad_op_args_with_type.replace(input_args_with_type=input_args_with_type_str)
+        op_def_name_str = "g" + op_proto.op_class.name
+        bprop_expander = "true" if op_proto.bprop_expander else "false"
+        return self.DO_GRAD_FUNCTION_BODY_TEMPLATE.replace(class_name=op_proto.op_class.name,
+                                                           grad_args_with_type=grad_args_with_type_str,
+                                                           grad_input_args=input_args_str,
+                                                           grad_input_args_with_optional=input_args_with_optional_str,
+                                                           is_multi=multi_output_str,
+                                                           view_arg=view_arg_str,
+                                                           op_def_name=op_def_name_str,
+                                                           bprop_expander=bprop_expander)
+
+    def _get_input_args(self, op_proto, has_type, with_optional):
+        """
+        Get the input arguments for the DoGrad function.
+
+        Args:
+            op_proto: The operator prototype.
+            has_type (bool): Whether to include type information for the arguments.
+
+        Returns:
+            list: A list of input arguments for the DoGrad function.
+        """
+        args_list = []
+        for op_arg in op_proto.op_args:
+            input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional_param(op_arg))
+            if has_type:
+                args_list.append(f"const {input_dtype} &{op_arg.arg_name}_tensor")
+            else:
+                if not with_optional and is_optional_param(op_arg):
+                    args_list.append(f"OptionalToValue({op_arg.arg_name}_tensor)")
+                else:
+                    args_list.append(f"{op_arg.arg_name}_tensor")
+        return args_list
+
+    def _get_view_str(self, is_view_op: bool, grad_args: list):
+        """
+        Get the view argument string for a DoGrad function.
+
+        Args:
+            is_view_op (bool): Whether the operator is a view operator.
+            grad_args (list): A list of gradient arguments.
+
+        Returns:
+            str: The view argument string.
+        """
+        view_arg_str = ''
+        for i, grad_arg in enumerate(grad_args):
+            if is_view_op and i == 0:
+                view_arg_str = ", " + grad_arg
+                break
+        return view_arg_str
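To make the generated registration statement concrete, the sketch below substitutes a hypothetical class name into the template string from this file, using the standard-library string.Template in place of MindSpore's template.Template:

from string import Template

# Same template text as AutoGradImplGenerator.auto_grad_reg_template; "Abs" is a
# hypothetical operator class name used only for illustration.
auto_grad_reg_template = Template(
    "const_cast<kernel::pyboost::${class_name}GradFunc&>("
    "kernel::pyboost::AutoGradFactory::Get()."
    "ops_auto_grad_registers().${class_name}GradFuncObj) = "
    "kernel::pyboost::${class_name}GradFunc(DoGrad${class_name});"
)
print(auto_grad_reg_template.substitute(class_name="Abs"))
# const_cast<kernel::pyboost::AbsGradFunc&>(kernel::pyboost::AutoGradFactory::Get().ops_auto_grad_registers().AbsGradFuncObj) = kernel::pyboost::AbsGradFunc(DoGradAbs);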
mindspore/ops_generate/auto_grad_reg_cc_generator.py (new file)
@@ -0,0 +1,93 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This module provides a generator class for creating C++ header files for AutoGrad registration functionality.
+"""
+
+import os
+
+import template
+from template import Template
+import gen_constants as K
+from gen_utils import save_file
+from base_generator import BaseGenerator
+from pyboost_utils import is_optional_param, get_input_dtype
+
+
+class AutoGradRegHeaderGenerator(BaseGenerator):
+    """
+    Generates C++ header files for the AutoGrad registration functionality based on operator prototypes.
+    """
+
+    def __init__(self):
+        """
+        Initialize the AutoGrad registration header generator with templates for code generation.
+        """
+        self.AUTO_GRAD_REG_H_TEMPLATE = template.AUTO_GRAD_REG_H_TEMPLATE
+        self.op_type_enum_template = Template("k${class_name} = ${enum_val},\n")
+        self.op_grad_func_template = Template("using ${class_name}GradFunc = std::function<void(${grad_func_args})>;")
+        self.op_grad_func_obj_template = Template("${class_name}GradFunc ${class_name}GradFuncObj;")
+        self.op_grad_func_args_template = Template(
+            "const kernel::pyboost::OpPtr &, ${input_tensor_prt_args}"
+        )
+
+    def generate(self, work_path, op_protos):
+        """
+        Generate the AutoGrad registration header file.
+
+        Args:
+            work_path (str): The directory where the generated file should be saved.
+            op_protos (list): A list of operator prototypes used to generate the header.
+        """
+        op_type_enum_list = []
+        op_grad_func_list = []
+        op_grad_func_obj_list = []
+        index = 0
+        for op_proto in op_protos:
+            if op_proto.op_dispatch is None or op_proto.op_dispatch.is_comm_op:
+                continue
+            op_type_enum_list.append(self.op_type_enum_template.replace(class_name=op_proto.op_class.name,
+                                                                        enum_val=index))
+            grad_func_args_with_type_str = self._get_grad_func_args_with_type_str(op_proto)
+            op_grad_func_list.append(self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
+                                                                        grad_func_args=grad_func_args_with_type_str))
+            op_grad_func_obj_list.append(self.op_grad_func_obj_template.replace(class_name=op_proto.op_class.name))
+            index += 1
+
+        pyboost_func_h_str = self.AUTO_GRAD_REG_H_TEMPLATE.replace(op_enum=op_type_enum_list,
+                                                                   op_grad_func=op_grad_func_list,
+                                                                   op_grad_func_obj=op_grad_func_obj_list)
+
+        save_path = os.path.join(work_path, K.MS_OPS_KERNEL_FUNCTIONS_AUTO_GEN_PATH)
+        file_name = "auto_grad_op_reg.h"
+        save_file(save_path, file_name, pyboost_func_h_str)
+
+    def _get_grad_func_args_with_type_str(self, op_proto):
+        """
+        Get the gradient function arguments with type information.
+
+        Args:
+            op_proto: The operator prototype.
+
+        Returns:
+            str: A string of input tensor pointer arguments with types.
+        """
+        input_tensor_prt_args_str = ""
+        for op_arg in op_proto.op_args:
+            is_optional = is_optional_param(op_arg)
+            input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional)
+            input_tensor_prt_args_str += f"const {input_dtype} &, "
+
+        return self.op_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str.rstrip(', '))
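The three per-operator snippets emitted into auto_grad_op_reg.h follow directly from the templates above. A sketch with a hypothetical class name and an assumed argument type string (the real type comes from get_input_dtype(), which is not part of this diff):

from string import Template

op_type_enum_template = Template("k${class_name} = ${enum_val},\n")
op_grad_func_template = Template("using ${class_name}GradFunc = std::function<void(${grad_func_args})>;")
op_grad_func_obj_template = Template("${class_name}GradFunc ${class_name}GradFuncObj;")

# "Abs" and the TensorPtr argument type are assumptions for illustration only.
grad_func_args = "const kernel::pyboost::OpPtr &, const TensorPtr &"
print(op_type_enum_template.substitute(class_name="Abs", enum_val=0), end="")
print(op_grad_func_template.substitute(class_name="Abs", grad_func_args=grad_func_args))
print(op_grad_func_obj_template.substitute(class_name="Abs"))
# kAbs = 0,
# using AbsGradFunc = std::function<void(const kernel::pyboost::OpPtr &, const TensorPtr &)>;
# AbsGradFunc AbsGradFuncObj;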
mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py (new file)
@@ -0,0 +1,108 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Generates C++ helper files for primitive instance creation based on operator metadata.
+"""
+
+import os
+
+import gen_constants as K
+import gen_utils
+import pyboost_utils
+
+# refactored
+import template
+
+from base_generator import BaseGenerator
+
+
+class CppCreatePrimInstanceHelperGenerator(BaseGenerator):
+    """
+    This class is responsible for generating a helper file that contains
+    operation labels and default values for creating primitive instances in C++.
+    """
+
+    def __init__(self):
+        """
+        Initializes the generator with templates for operation labels and default values.
+        """
+        self.op_labels_template = template.op_labels_template
+        self.op_args_default_value_template = template.arg_default_value
+        self.op_label_template = template.Template(""" "$op_name": {$op_label_body},\n""")
+        self.op_arg_default_val_template = template.Template(""" "$op_name": {$op_arg_default_value},\n""")
+
+    def generate(self, work_path, op_protos):
+        """
+        Generates the content for the helper file and saves it to the specified path.
+
+        Args:
+            work_path (str): The directory where the generated file will be saved.
+            op_protos (list): A list of operation prototypes to generate content for.
+
+        Returns:
+            None
+        """
+        py_arg_default = self.generate_op_arg_default_value(op_protos)
+        py_labels = self.generate_op_labels(op_protos)
+        res_str = (template.PY_LICENCE_STR + py_arg_default + py_labels)
+
+        save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
+        file_name = "cpp_create_prim_instance_helper.py"
+        gen_utils.save_file(save_path, file_name, res_str)
+
+    def generate_op_labels(self, op_protos):
+        """
+        Generates a string containing labels for each operation.
+
+        Args:
+            op_protos (list): A list of operation prototypes.
+
+        Returns:
+            str: A string representing the labels in the specified format.
+        """
+        gen_label_list = []
+        for op_proto in op_protos:
+            labels = op_proto.op_labels
+            if labels is not None:
+                op_name = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
+                op_label_list = [f"\"{name}\": {value}" for name, value in labels.items()]
+                gen_label_list.append(self.op_label_template.replace(op_name=op_name, op_label_body=op_label_list))
+
+        return self.op_labels_template.replace(gen_label_py=gen_label_list)
+
+    def generate_op_arg_default_value(self, op_protos):
+        """
+        Generates a string containing default values for each operation's arguments.
+
+        Args:
+            op_protos (list): A list of operation prototypes.
+
+        Returns:
+            str: A string representing the default argument values in the specified format.
+        """
+        gen_default_list = []
+        for op_proto in op_protos:
+            arg_default_dict = {}
+            for op_arg in op_proto.op_args:
+                arg_default = op_arg.default
+                if arg_default is not None:
+                    arg_default_dict[op_arg.arg_name] = arg_default
+            if arg_default_dict:
+                op_name = pyboost_utils.get_op_name(op_proto.op_name, op_proto.op_class.name)
+                arg_default_list = [f"\"{key}\": {value}" for key, value in arg_default_dict.items()]
+                gen_default_list.append(self.op_arg_default_val_template.replace(op_name=op_name,
+                                                                                 op_arg_default_value=arg_default_list))
+
+        return self.op_args_default_value_template.replace(gen_default_py=gen_default_list)
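For reference, the per-operator entries written into cpp_create_prim_instance_helper.py have the shape sketched below; the operator name, label, and default values are hypothetical, and string.Template stands in for MindSpore's template.Template:

from string import Template

op_arg_default_val_template = Template(""" "$op_name": {$op_arg_default_value},\n""")
op_label_template = Template(""" "$op_name": {$op_label_body},\n""")

default_entry = op_arg_default_val_template.substitute(
    op_name="ReduceSum", op_arg_default_value='"axis": (), "keep_dims": False')
label_entry = op_label_template.substitute(
    op_name="ReduceSum", op_label_body='"side_effect_mem": True')
print(default_entry + label_entry, end="")
#  "ReduceSum": {"axis": (), "keep_dims": False},
#  "ReduceSum": {"side_effect_mem": True},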