mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
|
@@ -20,14 +20,14 @@ import os
|
|
|
20
20
|
|
|
21
21
|
import common.gen_constants as K
|
|
22
22
|
import common.gen_utils as gen_utils
|
|
23
|
-
import common.template as template
|
|
24
|
-
from common.base_generator import BaseGenerator
|
|
23
|
+
import common.template_utils as template
|
|
25
24
|
from common.op_proto import OpProto
|
|
26
|
-
from common.template import Template
|
|
25
|
+
from common.template_utils import Template
|
|
27
26
|
from pyboost import pyboost_utils
|
|
27
|
+
from op_def_py.base_op_prim_py_generator import BaseOpPrimPyGenerator, _generate_arg_handler, generate_py_op_deprecated
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
class OpPrimPyGenerator(BaseGenerator):
|
|
30
|
+
class OpPrimPyGenerator(BaseOpPrimPyGenerator):
|
|
31
31
|
"""
|
|
32
32
|
Generates Python code for primitive operators based on provided specifications.
|
|
33
33
|
"""
|
|
@@ -87,7 +87,7 @@ class OpPrimPyGenerator(BaseGenerator):
|
|
|
87
87
|
|
|
88
88
|
pyboost_import_header = self.generate_pyboost_import_header(op_protos)
|
|
89
89
|
res_str = template.PY_LICENSE_STR + \
|
|
90
|
-
|
|
90
|
+
template.OPS_PY_PRIM_HEADER + pyboost_import_header + gen_py
|
|
91
91
|
|
|
92
92
|
save_path = os.path.join(work_path, K.PY_AUTO_GEN_PATH)
|
|
93
93
|
file_name = f"{file_pre}_ops_prim.py"
|
|
@@ -111,113 +111,6 @@ class OpPrimPyGenerator(BaseGenerator):
|
|
|
111
111
|
pyboost_import_header += header
|
|
112
112
|
return pyboost_import_header
|
|
113
113
|
|
|
114
|
-
def _process_args(self, op_proto: OpProto):
|
|
115
|
-
"""
|
|
116
|
-
Processes operator arguments to categorize them for code generation.
|
|
117
|
-
|
|
118
|
-
Args:
|
|
119
|
-
op_proto (OpProto): The operator prototype.
|
|
120
|
-
|
|
121
|
-
Returns:
|
|
122
|
-
tuple: A tuple containing processed arguments.
|
|
123
|
-
"""
|
|
124
|
-
inputs_name = []
|
|
125
|
-
args_name = []
|
|
126
|
-
args_assign = []
|
|
127
|
-
inputs_default = {}
|
|
128
|
-
init_args_with_default = []
|
|
129
|
-
args_handlers = {}
|
|
130
|
-
|
|
131
|
-
for arg in op_proto.op_args:
|
|
132
|
-
# step1: get args infos:
|
|
133
|
-
if arg.is_prim_init:
|
|
134
|
-
# step1.1: get args name:
|
|
135
|
-
args_name.append(arg.arg_name)
|
|
136
|
-
# step1.2: get args assign with default value:
|
|
137
|
-
if arg.default is not None:
|
|
138
|
-
init_args_with_default.append(f"""{arg.arg_name}={arg.default}""")
|
|
139
|
-
else:
|
|
140
|
-
init_args_with_default.append(f"""{arg.arg_name}""")
|
|
141
|
-
|
|
142
|
-
# step1.3: get args set prim arg expression:
|
|
143
|
-
assign_str = self._get_assign_str_by_type_it(op_proto.op_class.name, arg)
|
|
144
|
-
if arg.arg_handler:
|
|
145
|
-
assign_str = (
|
|
146
|
-
f' self._set_prim_arg_with_handler('
|
|
147
|
-
f'"{arg.arg_name}", {assign_str}, {arg.arg_handler})'
|
|
148
|
-
)
|
|
149
|
-
else:
|
|
150
|
-
assign_str = f""" self._set_prim_arg("{arg.arg_name}", {assign_str})"""
|
|
151
|
-
args_assign.append(assign_str)
|
|
152
|
-
# step2: get inputs infos:
|
|
153
|
-
else:
|
|
154
|
-
# step2.1: get inputs name:
|
|
155
|
-
inputs_name.append(arg.arg_name)
|
|
156
|
-
|
|
157
|
-
# step2.2: get default value of inputs:
|
|
158
|
-
if arg.default is not None:
|
|
159
|
-
inputs_default[arg.arg_name] = arg.default
|
|
160
|
-
|
|
161
|
-
# step2.3: get args_handler functions for inputs
|
|
162
|
-
if arg.arg_handler:
|
|
163
|
-
args_handlers[arg.arg_name] = arg.arg_handler
|
|
164
|
-
|
|
165
|
-
return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
|
|
166
|
-
|
|
167
|
-
def _get_assign_str_by_type_it(self, class_name, arg):
|
|
168
|
-
"""
|
|
169
|
-
Generates assignment string with type casting.
|
|
170
|
-
|
|
171
|
-
Args:
|
|
172
|
-
class_name (str): The name of the class.
|
|
173
|
-
arg (OpArg): The operator argument.
|
|
174
|
-
|
|
175
|
-
Returns:
|
|
176
|
-
str: A string representing the assignment.
|
|
177
|
-
"""
|
|
178
|
-
assign_str = ""
|
|
179
|
-
type_cast = arg.type_cast
|
|
180
|
-
if type_cast:
|
|
181
|
-
assign_str += f"type_it('{class_name}', '{arg.arg_name}', {arg.arg_name}, "
|
|
182
|
-
if len(type_cast) == 1:
|
|
183
|
-
assign_str += gen_utils.get_type_str(type_cast[0]) + ', '
|
|
184
|
-
else:
|
|
185
|
-
assign_str += '(' + ', '.join(gen_utils.get_type_str(ct) for ct in type_cast) + '), '
|
|
186
|
-
assign_str += gen_utils.get_type_str(arg.arg_dtype) + ')'
|
|
187
|
-
else:
|
|
188
|
-
assign_str = arg.arg_name
|
|
189
|
-
return assign_str
|
|
190
|
-
|
|
191
|
-
def _generate_class_desc(self, op_proto: OpProto, input_args, init_args, doc_dic):
|
|
192
|
-
"""
|
|
193
|
-
Generates a class description based on the operator prototype.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
op_proto (OpProto): The operator prototype.
|
|
197
|
-
input_args (list): List of input argument names.
|
|
198
|
-
init_args (list): List of initialization argument names.
|
|
199
|
-
doc_dic (dict): Documentation dictionary.
|
|
200
|
-
|
|
201
|
-
Returns:
|
|
202
|
-
str: A string containing the class description.
|
|
203
|
-
"""
|
|
204
|
-
if op_proto.op_function and op_proto.op_function.disable:
|
|
205
|
-
# if function disabled, function name is equal to operator_name
|
|
206
|
-
return gen_utils.get_op_description(op_proto.op_name, doc_dic)
|
|
207
|
-
|
|
208
|
-
# If function is a released API, refer to the function doc.
|
|
209
|
-
init_args_str = ", ".join(init_args)
|
|
210
|
-
input_args_str = ", ".join(input_args)
|
|
211
|
-
args_str = ", ".join(input_args + init_args)
|
|
212
|
-
|
|
213
|
-
description_template = Template(template.PRIMITIVE_CLASS_DESC)
|
|
214
|
-
description_str = description_template.replace(class_name=op_proto.op_class.name,
|
|
215
|
-
init_args_str=init_args_str,
|
|
216
|
-
input_args_str=input_args_str,
|
|
217
|
-
func_name=op_proto.op_function.name,
|
|
218
|
-
args_str=args_str)
|
|
219
|
-
return description_str
|
|
220
|
-
|
|
221
114
|
def _generate_init_code(self, args_assign, init_args_with_default, op_proto: OpProto):
|
|
222
115
|
"""
|
|
223
116
|
Generates the __init__ method code for the operator primitive class.
|
|
@@ -242,50 +135,6 @@ class OpPrimPyGenerator(BaseGenerator):
|
|
|
242
135
|
init_code_str += f"\n"
|
|
243
136
|
return init_code_str
|
|
244
137
|
|
|
245
|
-
def _get_init_code(self, init_code, op_proto: OpProto):
|
|
246
|
-
"""
|
|
247
|
-
Generates additional initialization code for the operator primitive class.
|
|
248
|
-
|
|
249
|
-
Args:
|
|
250
|
-
init_code (str): Existing initialization code.
|
|
251
|
-
op_proto (OpProto): The operator prototype.
|
|
252
|
-
|
|
253
|
-
Returns:
|
|
254
|
-
str: A string containing additional initialization code.
|
|
255
|
-
"""
|
|
256
|
-
labels_dic = op_proto.op_labels
|
|
257
|
-
if labels_dic:
|
|
258
|
-
if init_code:
|
|
259
|
-
init_code += "\n"
|
|
260
|
-
init_code += "\n".join([f""" self.add_prim_attr("{k}", {v})""" for k, v in labels_dic.items()])
|
|
261
|
-
|
|
262
|
-
return init_code if init_code else f""" pass"""
|
|
263
|
-
|
|
264
|
-
def _generate_call_code(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
|
|
265
|
-
"""
|
|
266
|
-
Generates the __call__ method code for the operator primitive class.
|
|
267
|
-
|
|
268
|
-
Args:
|
|
269
|
-
args_handlers (dict): Dictionary of argument handlers.
|
|
270
|
-
init_args (list): List of initialization argument names.
|
|
271
|
-
inputs_args (list): List of input argument names.
|
|
272
|
-
inputs_default (dict): Dictionary of default input values.
|
|
273
|
-
op_proto (OpProto): The operator prototype.
|
|
274
|
-
|
|
275
|
-
Returns:
|
|
276
|
-
str: A string containing the __call__ method code.
|
|
277
|
-
"""
|
|
278
|
-
call_code_str = ""
|
|
279
|
-
call_args = []
|
|
280
|
-
for name in inputs_args:
|
|
281
|
-
call_args.append(f"{name}={inputs_default[name]}" if name in inputs_default else name)
|
|
282
|
-
call_method_args_str = ", ".join(call_args)
|
|
283
|
-
call_method_body_str = self._get_call_method_body_str(args_handlers, init_args, inputs_args, inputs_default,
|
|
284
|
-
op_proto)
|
|
285
|
-
call_code_str += f""" def __call__(self, {call_method_args_str}):"""
|
|
286
|
-
call_code_str += f"""{call_method_body_str}"""
|
|
287
|
-
return call_code_str
|
|
288
|
-
|
|
289
138
|
def _get_call_method_body_str(self, args_handlers, init_args, inputs_args, inputs_default, op_proto: OpProto):
|
|
290
139
|
"""
|
|
291
140
|
Generates the body of the __call__ method.
|
|
@@ -334,159 +183,3 @@ class OpPrimPyGenerator(BaseGenerator):
|
|
|
334
183
|
call_method_body_str += f"""
|
|
335
184
|
return super().__call__({call_args_list_str})\n"""
|
|
336
185
|
return call_method_body_str
|
|
337
|
-
|
|
338
|
-
def _generate_py_op_signature(self, op_proto: OpProto, args_name, args_default):
|
|
339
|
-
"""
|
|
340
|
-
Generates the __mindspore_signature__ for the operator.
|
|
341
|
-
|
|
342
|
-
Args:
|
|
343
|
-
op_proto (OpProto): The operator prototype.
|
|
344
|
-
args_name (list): List of argument names.
|
|
345
|
-
args_default (dict): Dictionary of default argument values.
|
|
346
|
-
|
|
347
|
-
Returns:
|
|
348
|
-
str: A string containing the __mindspore_signature__ code.
|
|
349
|
-
"""
|
|
350
|
-
op_name = op_proto.op_name
|
|
351
|
-
args_signature = op_proto.op_args_signature
|
|
352
|
-
|
|
353
|
-
if args_signature is None and not args_default:
|
|
354
|
-
return ''
|
|
355
|
-
|
|
356
|
-
signature_code = f"""\n __mindspore_signature__ = """
|
|
357
|
-
|
|
358
|
-
# Init rw.
|
|
359
|
-
read_list, ref_list, write_list = gen_utils.init_args_signature_rw(args_signature)
|
|
360
|
-
_check_signature_arg_valid(op_name, write_list, args_name)
|
|
361
|
-
_check_signature_arg_valid(op_name, read_list, args_name)
|
|
362
|
-
_check_signature_arg_valid(op_name, ref_list, args_name)
|
|
363
|
-
|
|
364
|
-
# Init dtype group.
|
|
365
|
-
same_dtype_groups, dtype_count = gen_utils.get_same_dtype_groups(args_signature, args_name)
|
|
366
|
-
_check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
|
|
367
|
-
|
|
368
|
-
# Only one dtype_group is set.
|
|
369
|
-
if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
|
|
370
|
-
signature_code += '('
|
|
371
|
-
for _ in range(len(args_name) - 1):
|
|
372
|
-
signature_code += 'sig.sig_dtype.T, '
|
|
373
|
-
signature_code += 'sig.sig_dtype.T)\n'
|
|
374
|
-
return signature_code
|
|
375
|
-
|
|
376
|
-
# Set sig.make_sig.
|
|
377
|
-
signature_code += f""" (\n"""
|
|
378
|
-
for arg_name in args_name:
|
|
379
|
-
signature_code += f""" sig.make_sig('{arg_name}'"""
|
|
380
|
-
signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
|
|
381
|
-
if arg_name in same_dtype_groups:
|
|
382
|
-
signature_code += f""", """ + signature_get_dtype_label(same_dtype_groups[arg_name])
|
|
383
|
-
if arg_name in args_default:
|
|
384
|
-
signature_code += f""", default=""" + str(args_default[arg_name])
|
|
385
|
-
signature_code += f"""),\n"""
|
|
386
|
-
signature_code += f""" )\n"""
|
|
387
|
-
return signature_code
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
|
|
391
|
-
"""
|
|
392
|
-
Validates that all signature arguments are present in the list of argument names.
|
|
393
|
-
|
|
394
|
-
Args:
|
|
395
|
-
op_name (str): The name of the operator.
|
|
396
|
-
sig_arg_names (list): List of signature argument names.
|
|
397
|
-
args_names (list): List of actual argument names.
|
|
398
|
-
|
|
399
|
-
Raises:
|
|
400
|
-
ValueError: If a signature argument is not found in the list of argument names.
|
|
401
|
-
"""
|
|
402
|
-
for sig_arg_name in sig_arg_names:
|
|
403
|
-
if sig_arg_name not in args_names:
|
|
404
|
-
raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
def signature_get_dtype_label(index):
|
|
408
|
-
"""
|
|
409
|
-
Generates the label for the data type in the signature.
|
|
410
|
-
|
|
411
|
-
Args:
|
|
412
|
-
index (int): The index of the data type.
|
|
413
|
-
|
|
414
|
-
Returns:
|
|
415
|
-
str: The label string for the data type.
|
|
416
|
-
"""
|
|
417
|
-
dtype_index = ''
|
|
418
|
-
if index > 0:
|
|
419
|
-
dtype_index = f"""{index}"""
|
|
420
|
-
return f"""dtype=sig.sig_dtype.T{dtype_index}"""
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
|
|
424
|
-
"""
|
|
425
|
-
Determines the read-write label for an argument in the signature.
|
|
426
|
-
|
|
427
|
-
Args:
|
|
428
|
-
arg_name (str): The name of the argument.
|
|
429
|
-
write_list (list): List of arguments that are writable.
|
|
430
|
-
read_list (list): List of arguments that are readable.
|
|
431
|
-
ref_list (list): List of arguments that are references.
|
|
432
|
-
|
|
433
|
-
Returns:
|
|
434
|
-
str: The read-write label for the argument.
|
|
435
|
-
"""
|
|
436
|
-
for rw_arg_name in write_list:
|
|
437
|
-
if rw_arg_name == arg_name:
|
|
438
|
-
return ', sig.sig_rw.RW_WRITE'
|
|
439
|
-
for read_arg_name in read_list:
|
|
440
|
-
if read_arg_name == arg_name:
|
|
441
|
-
return ', sig.sig_rw.RW_READ'
|
|
442
|
-
for ref_arg_name in ref_list:
|
|
443
|
-
if ref_arg_name == arg_name:
|
|
444
|
-
return ', sig.sig_rw.RW_REF'
|
|
445
|
-
return ''
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
def generate_py_op_deprecated(deprecated):
    """
    Generates the deprecated decorator for an operator.

    Args:
        deprecated (dict): The deprecation information. It is read through the keys
            ``version``, ``substitute`` and ``use_substitute``. ``None`` means the
            operator is not deprecated.

    Returns:
        str: A string containing the deprecated decorator (terminated by a newline),
        or an empty string when ``deprecated`` is None.

    Raises:
        ValueError: If ``version``, ``substitute`` or ``use_substitute`` is missing or None,
            or if ``use_substitute`` is not a bool.
    """
    if deprecated is None:
        return ''

    version = deprecated.get("version")
    if version is None:
        raise ValueError("The version of deprecated can't be None.")

    substitute = deprecated.get("substitute")
    if substitute is None:
        raise ValueError("The substitute of deprecated can't be None.")

    use_substitute = deprecated.get("use_substitute")
    if use_substitute is None:
        raise ValueError("The use_substitute of deprecated can't be None.")
    # Only the real booleans are accepted; truthy/falsy stand-ins such as 1 or "True" are rejected.
    if not isinstance(use_substitute, bool):
        raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")

    # NOTE(review): the leading whitespace inside this literal is reproduced as seen here;
    # it may have been collapsed from a wider indent during extraction - confirm upstream.
    decorator_line = f""" @deprecated("{version}", "{substitute}", {use_substitute})\n"""
    return decorator_line
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
|
|
477
|
-
"""
|
|
478
|
-
Generates the argument handler call for an argument.
|
|
479
|
-
|
|
480
|
-
Args:
|
|
481
|
-
class_name (str): The name of the class.
|
|
482
|
-
arg (str): The name of the argument.
|
|
483
|
-
arg_handler (str): The handler function for the argument.
|
|
484
|
-
is_optional (bool): Indicates whether the argument is optional.
|
|
485
|
-
|
|
486
|
-
Returns:
|
|
487
|
-
str: The argument handler call string.
|
|
488
|
-
"""
|
|
489
|
-
arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
|
|
490
|
-
if is_optional:
|
|
491
|
-
arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
|
|
492
|
-
return arg_handler_call
|
|
@@ -23,7 +23,7 @@ from common.template import Template
|
|
|
23
23
|
import common.gen_constants as K
|
|
24
24
|
from common.gen_utils import save_file
|
|
25
25
|
from common.base_generator import BaseGenerator
|
|
26
|
-
from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output
|
|
26
|
+
from pyboost.pyboost_utils import is_optional_param, get_input_dtype, is_op_multi_output, get_output_dtype
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class AutoGradImplGenerator(BaseGenerator):
|
|
@@ -38,6 +38,8 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
38
38
|
self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE
|
|
39
39
|
self.AUTO_GRAD_IMPL_CC_TEMPLATE = template.AUTO_GRAD_IMPL_CC_TEMPLATE
|
|
40
40
|
self.DO_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_GRAD_FUNCTION_BODY_TEMPLATE
|
|
41
|
+
self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
|
|
42
|
+
self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE = template.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
|
|
41
43
|
self.auto_grad_reg_template = Template("const_cast<kernel::pyboost::${class_name}GradFunc&>(" + \
|
|
42
44
|
"kernel::pyboost::AutoGradFactory::Get()." + \
|
|
43
45
|
"ops_auto_grad_registers().${class_name}GradFuncObj) = " + \
|
|
@@ -45,6 +47,9 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
45
47
|
self.do_grad_op_args_with_type = Template(
|
|
46
48
|
"const kernel::pyboost::OpPtr &op, ${input_args_with_type}"
|
|
47
49
|
)
|
|
50
|
+
self.do_grad_view_op_args_with_type = Template(
|
|
51
|
+
"${output_args_with_type}, ${input_args_with_type}"
|
|
52
|
+
)
|
|
48
53
|
|
|
49
54
|
def generate(self, work_path, op_protos):
|
|
50
55
|
"""
|
|
@@ -60,8 +65,13 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
60
65
|
for op_proto in op_protos:
|
|
61
66
|
if op_proto.op_dispatch is None:
|
|
62
67
|
continue
|
|
68
|
+
# the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
|
|
69
|
+
if op_proto.op_view and not op_proto.bprop_expander:
|
|
70
|
+
continue
|
|
63
71
|
auto_grad_reg_list.append(self.auto_grad_reg_template.replace(class_name=op_proto.op_class.name))
|
|
64
|
-
|
|
72
|
+
do_single_grad_op_str = self._get_single_do_grad_view_op(op_proto)\
|
|
73
|
+
if op_proto.op_view else self._get_single_do_grad_op(op_proto)
|
|
74
|
+
do_grad_op_list.append(do_single_grad_op_str)
|
|
65
75
|
ops_inc_head_set.add(self.OP_DEF_INC_HEAD_TEMPLATE.replace(prefix_char=op_proto.op_class.name[0].lower()))
|
|
66
76
|
pyboost_func_h_str = self.AUTO_GRAD_IMPL_CC_TEMPLATE.replace(do_grad_op=do_grad_op_list,
|
|
67
77
|
auto_grad_reg=auto_grad_reg_list,
|
|
@@ -80,12 +90,11 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
80
90
|
Returns:
|
|
81
91
|
str: The generated DoGrad function string.
|
|
82
92
|
"""
|
|
83
|
-
input_args_str = self._get_input_args(op_proto, False, False,
|
|
84
|
-
input_args_with_optional_str = self._get_input_args(op_proto, False, True,
|
|
85
|
-
input_args_with_type_str = self._get_input_args(op_proto, True, False,
|
|
93
|
+
input_args_str = self._get_input_args(op_proto, False, False, False)
|
|
94
|
+
input_args_with_optional_str = self._get_input_args(op_proto, False, True, False)
|
|
95
|
+
input_args_with_type_str = self._get_input_args(op_proto, True, False, False)
|
|
86
96
|
inner_grad_args_with_type = self._get_input_args(op_proto, True, False, False)
|
|
87
97
|
multi_output_str = 'Multi' if is_op_multi_output(op_proto.op_returns) else ''
|
|
88
|
-
view_arg_str = self._get_view_str(op_proto.op_view, input_args_str)
|
|
89
98
|
grad_args_with_type_str = self.do_grad_op_args_with_type.replace(input_args_with_type=input_args_with_type_str)
|
|
90
99
|
inner_grad_args_with_type =\
|
|
91
100
|
self.do_grad_op_args_with_type.replace(input_args_with_type=inner_grad_args_with_type)
|
|
@@ -94,22 +103,62 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
94
103
|
FALSE = "false"
|
|
95
104
|
bprop_expander = TRUE if op_proto.bprop_expander else FALSE
|
|
96
105
|
non_differentiable = TRUE if op_proto.non_differentiable else FALSE
|
|
97
|
-
|
|
98
|
-
convert_basic_to_value = ''
|
|
99
|
-
else:
|
|
100
|
-
input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(op_proto,
|
|
101
|
-
input_args_with_optional_str)
|
|
106
|
+
|
|
102
107
|
return self.DO_GRAD_FUNCTION_BODY_TEMPLATE.replace(class_name=op_proto.op_class.name,
|
|
103
108
|
inner_grad_args_with_type=inner_grad_args_with_type,
|
|
104
109
|
grad_args_with_type=grad_args_with_type_str,
|
|
105
110
|
grad_input_args=input_args_str,
|
|
106
111
|
grad_input_args_with_optional=input_args_with_optional_str,
|
|
107
112
|
is_multi=multi_output_str,
|
|
108
|
-
view_arg=view_arg_str,
|
|
109
113
|
op_def_name=op_def_name_str,
|
|
110
114
|
bprop_expander=bprop_expander,
|
|
111
|
-
non_differentiable=non_differentiable
|
|
112
|
-
|
|
115
|
+
non_differentiable=non_differentiable)
|
|
116
|
+
|
|
117
|
+
def _get_single_do_grad_view_op(self, op_proto):
    """
    Generate the DoGrad function for a single view operator prototype.

    Args:
        op_proto: The operator prototype for which the DoGrad function is generated.

    Returns:
        str: The generated DoGrad function string.
    """
    # View ops rendered with the customize template; for these no conversion code is
    # emitted (``convert_basic_to_value`` stays empty).
    customized_view_ops = ("reshape", "expand_dims", "transpose", "slice_ext_view",
                           "select_ext_view", "transpose_ext_view")

    # The last positional flag is ``use_basic_type``: the outer signature is built with it
    # switched on, the inner one with it switched off.
    input_args_str = self._get_input_args(op_proto, False, False, True)
    input_args_with_optional_str = self._get_input_args(op_proto, False, True, True)
    input_args_with_type_str = self._get_input_args(op_proto, True, False, True)
    inner_input_args_with_type = self._get_input_args(op_proto, True, False, False)
    view_arg_str = self._get_view_str(input_args_str)

    grad_args_with_type_str = self.do_grad_view_op_args_with_type.replace(
        input_args_with_type=input_args_with_type_str,
        output_args_with_type=self._get_output_arg(op_proto))
    inner_grad_args_with_type = self.do_grad_view_op_args_with_type.replace(
        output_args_with_type="const ValuePtr &output_value",
        input_args_with_type=inner_input_args_with_type)

    op_def_name_str = "g" + op_proto.op_class.name
    # The template expects C++ boolean literals.
    bprop_expander = "true" if op_proto.bprop_expander else "false"
    non_differentiable = "true" if op_proto.non_differentiable else "false"

    if op_proto.op_name in customized_view_ops:
        body_template = self.DO_VIEW_CUSTOMIZE_GRAD_FUNCTION_BODY_TEMPLATE
        convert_basic_to_value = ""
    else:
        body_template = self.DO_VIEW_GRAD_FUNCTION_BODY_TEMPLATE
        input_args_with_optional_str, convert_basic_to_value = self._get_convert_str(
            op_proto, input_args_with_optional_str)

    return body_template.replace(class_name=op_proto.op_class.name,
                                 inner_grad_args_with_type=inner_grad_args_with_type,
                                 grad_args_with_type=grad_args_with_type_str,
                                 grad_input_args=input_args_str,
                                 grad_input_args_with_optional=input_args_with_optional_str,
                                 view_arg=view_arg_str,
                                 op_def_name=op_def_name_str,
                                 bprop_expander=bprop_expander,
                                 non_differentiable=non_differentiable,
                                 convert_basic_to_value=convert_basic_to_value)
|
|
161
|
+
|
|
113
162
|
|
|
114
163
|
def _get_input_args(self, op_proto, has_type, with_optional, use_basic_type=False):
|
|
115
164
|
"""
|
|
@@ -134,6 +183,15 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
134
183
|
args_list.append(f"{op_arg.arg_name}_tensor")
|
|
135
184
|
return args_list
|
|
136
185
|
|
|
186
|
+
def _get_output_arg(self, op_proto):
|
|
187
|
+
# for view operators, the output is tensor or vector<tensor>
|
|
188
|
+
if len(op_proto.op_returns) != 1:
|
|
189
|
+
raise ValueError(f"the output of {op_proto.op_name} is not tensor, ",
|
|
190
|
+
"tuple[tensor] or list[tensor], which is not not as expected")
|
|
191
|
+
output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
|
|
192
|
+
output_arg = f"const {output_dtype} &output"
|
|
193
|
+
return output_arg
|
|
194
|
+
|
|
137
195
|
def _get_convert_str(self, op_proto, args_name):
|
|
138
196
|
"""
|
|
139
197
|
Get the input convert func for the DoGrad function.
|
|
@@ -161,12 +219,11 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
161
219
|
args_name_list.append(out_arg_name)
|
|
162
220
|
return args_name_list, convert_funcs
|
|
163
221
|
|
|
164
|
-
def _get_view_str(self,
|
|
222
|
+
def _get_view_str(self, grad_args: list):
|
|
165
223
|
"""
|
|
166
224
|
Get the view argument string for a DoGrad function.
|
|
167
225
|
|
|
168
226
|
Args:
|
|
169
|
-
is_view_op (bool): Whether the operator is a view operator.
|
|
170
227
|
grad_args (list): A list of gradient arguments.
|
|
171
228
|
|
|
172
229
|
Returns:
|
|
@@ -174,7 +231,7 @@ class AutoGradImplGenerator(BaseGenerator):
|
|
|
174
231
|
"""
|
|
175
232
|
view_arg_str = ''
|
|
176
233
|
for i, grad_arg in enumerate(grad_args):
|
|
177
|
-
if
|
|
234
|
+
if i == 0:
|
|
178
235
|
view_arg_str = ", " + grad_arg
|
|
179
236
|
break
|
|
180
237
|
return view_arg_str
|
|
@@ -23,7 +23,7 @@ from common.template import Template
|
|
|
23
23
|
import common.gen_constants as K
|
|
24
24
|
from common.gen_utils import save_file
|
|
25
25
|
from common.base_generator import BaseGenerator
|
|
26
|
-
from pyboost.pyboost_utils import is_optional_param, get_input_dtype
|
|
26
|
+
from pyboost.pyboost_utils import is_optional_param, get_input_dtype, get_output_dtype
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class AutoGradRegHeaderGenerator(BaseGenerator):
|
|
@@ -42,6 +42,9 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
|
|
|
42
42
|
self.op_grad_func_args_template = Template(
|
|
43
43
|
"const kernel::pyboost::OpPtr &, ${input_tensor_prt_args}"
|
|
44
44
|
)
|
|
45
|
+
self.op_view_grad_func_args_template = Template(
|
|
46
|
+
"${output_tensor_prt_args}, ${input_tensor_prt_args}"
|
|
47
|
+
)
|
|
45
48
|
|
|
46
49
|
def generate(self, work_path, op_protos):
|
|
47
50
|
"""
|
|
@@ -60,9 +63,13 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
|
|
|
60
63
|
continue
|
|
61
64
|
op_type_enum_list.append(self.op_type_enum_template.replace(class_name=op_proto.op_class.name,
|
|
62
65
|
enum_val=index))
|
|
66
|
+
# the backward func of flatten_ext and t_ext are implemented by other view ops, just continue
|
|
67
|
+
if op_proto.op_view and not op_proto.bprop_expander:
|
|
68
|
+
continue
|
|
63
69
|
grad_func_args_with_type_str = self._get_grad_func_args_with_type_str(op_proto)
|
|
64
|
-
op_grad_func_list.append(
|
|
65
|
-
|
|
70
|
+
op_grad_func_list.append(
|
|
71
|
+
self.op_grad_func_template.replace(class_name=op_proto.op_class.name,
|
|
72
|
+
grad_func_args=grad_func_args_with_type_str))
|
|
66
73
|
op_grad_func_obj_list.append(self.op_grad_func_obj_template.replace(class_name=op_proto.op_class.name))
|
|
67
74
|
index += 1
|
|
68
75
|
|
|
@@ -89,5 +96,15 @@ class AutoGradRegHeaderGenerator(BaseGenerator):
|
|
|
89
96
|
is_optional = is_optional_param(op_arg)
|
|
90
97
|
input_dtype = get_input_dtype(op_arg.arg_dtype, is_optional, op_proto.op_view)
|
|
91
98
|
input_tensor_prt_args_str += f"const {input_dtype} &, "
|
|
92
|
-
|
|
93
|
-
|
|
99
|
+
input_tensor_prt_args_str = input_tensor_prt_args_str.rstrip(', ')
|
|
100
|
+
if not op_proto.op_view:
|
|
101
|
+
return self.op_grad_func_args_template.replace(input_tensor_prt_args=\
|
|
102
|
+
input_tensor_prt_args_str)
|
|
103
|
+
# for view operators, the output is tensor or vector<tensor>
|
|
104
|
+
if len(op_proto.op_returns) != 1:
|
|
105
|
+
raise ValueError(f"the output of {op_proto.op_name} is not tensor,",
|
|
106
|
+
"tuple[tensor] or list[tensor], which is not not as expected")
|
|
107
|
+
output_dtype = get_output_dtype(op_proto.op_returns[0].arg_dtype)
|
|
108
|
+
output_tensor_prt_args_str = f"const {output_dtype} &"
|
|
109
|
+
return self.op_view_grad_func_args_template.replace(input_tensor_prt_args=input_tensor_prt_args_str,
|
|
110
|
+
output_tensor_prt_args=output_tensor_prt_args_str)
|
|
@@ -16,9 +16,6 @@
|
|
|
16
16
|
Generate pyboost function from pyboost_op.yaml
|
|
17
17
|
"""
|
|
18
18
|
|
|
19
|
-
import os
|
|
20
|
-
import shutil
|
|
21
|
-
import logging
|
|
22
19
|
from resources.resource_list import ResourceType
|
|
23
20
|
from common import gen_constants as K
|
|
24
21
|
from api.functions_cc_generator import FunctionsGenerator, FunctionsHeaderGenerator
|
|
@@ -48,18 +45,6 @@ from .auto_grad_impl_cc_generator import AutoGradImplGenerator
|
|
|
48
45
|
from .auto_grad_reg_cc_generator import AutoGradRegHeaderGenerator
|
|
49
46
|
|
|
50
47
|
|
|
51
|
-
def clear_old_generated_code(work_path):
    """
    Delete old generated files to prevent compilation failure.

    Args:
        work_path (str): Root directory that the generated-code directories are relative to.
    """
    stale_dirs = ('mindspore/ops/kernel/common/pyboost',
                  'mindspore/ops/kernel/functions/auto_generate',
                  'mindspore/ccsrc/runtime/pynative/op_function')
    for rel_path in stale_dirs:
        real_path = os.path.join(work_path, rel_path)
        # Directories that do not exist are skipped silently.
        if os.path.exists(real_path):
            shutil.rmtree(real_path)
            logging.warning("rm file %s", real_path)
|
|
61
|
-
|
|
62
|
-
|
|
63
48
|
def gen_pyboost_code(resource_mgr):
|
|
64
49
|
""" gen_pyboost_code """
|
|
65
50
|
work_path = K.WORK_DIR
|
|
@@ -67,7 +52,6 @@ def gen_pyboost_code(resource_mgr):
|
|
|
67
52
|
doc_yaml_data = resource_mgr.get_resource(ResourceType.OP_DOC_YAML)
|
|
68
53
|
mint_func_protos = resource_mgr.get_resource(ResourceType.MINT_FUNC_PROTOS)
|
|
69
54
|
alias_func_mapping = resource_mgr.get_resource(ResourceType.ALIAS_API_MAPPING)
|
|
70
|
-
clear_old_generated_code(work_path)
|
|
71
55
|
call_pyboost_inner_prim_generator(work_path, op_protos)
|
|
72
56
|
call_pyboost_functions_py_generator(work_path, op_protos, doc_yaml_data)
|
|
73
57
|
call_pyboost_functions_h_generator(work_path, op_protos)
|
|
@@ -47,14 +47,15 @@ class OpTemplateParser:
|
|
|
47
47
|
self.op_proto = op_proto
|
|
48
48
|
self.tensor_arg_handler_prt_template = Template(
|
|
49
49
|
"parse_args.arg_list_[${idx}] = "
|
|
50
|
-
"
|
|
50
|
+
"PyLong_FromLong((*pynative::${func_str}(\"${func_name}\", \"${op_arg_name}\", "
|
|
51
51
|
"parse_args.arg_list_[${idx}]))->value());\n"
|
|
52
52
|
"parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
|
|
53
53
|
"parse_args.dst_types_[${idx}] = ${new_type};\n"
|
|
54
54
|
)
|
|
55
55
|
self.function_arg_handler_prt_template = Template(
|
|
56
56
|
"parse_args.arg_list_[${idx}] = "
|
|
57
|
-
"
|
|
57
|
+
"PyLong_FromLong((*${func_str}(\"${func_name}\", \"${op_arg_name}\", "
|
|
58
|
+
"parse_args.arg_list_[${idx}]))->value());\n"
|
|
58
59
|
"parse_args.src_types_[${idx}] = ops::OP_DTYPE::DT_BEGIN;\n"
|
|
59
60
|
"parse_args.dst_types_[${idx}] = ${new_type};\n"
|
|
60
61
|
)
|