mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Op Proto module for defining operator prototypes and their arguments."""
|
|
17
|
+
|
|
18
|
+
import os
|
|
19
|
+
from typing import Dict
|
|
20
|
+
|
|
21
|
+
from resources.resource_loader import ResourceLoader
|
|
22
|
+
from resources.resource_list import ResourceType
|
|
23
|
+
|
|
24
|
+
from . import gen_constants as K
|
|
25
|
+
from .gen_utils import safe_load_yaml_from_dir
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OpArg:
    """
    Description of a single operator argument.

    Every constructor parameter is stored, unchanged, as an instance attribute
    of the same name; the class performs no validation or conversion.

    Attributes:
        arg_name (str): Name of the argument.
        arg_dtype (str): Declared data type of the argument.
        type_cast (list): Type casts applicable to the argument.
        is_type_id (bool): Whether the argument is a type identifier. Defaults to False.
        as_init_arg (bool): Whether the argument is an initialization argument. Defaults to False.
        default: Default value of the argument. Defaults to -1
            (NOTE(review): -1 looks like a "no default given" sentinel -- confirm with the caller).
        inplace (str): Name of the inplace tensor, if applicable; '' otherwise.
        is_prim_init (bool): Whether the argument is a primitive initialization argument. Defaults to False.
        arg_handler (str): Handler for the argument, if applicable; '' otherwise.
    """

    def __init__(self, arg_name, arg_dtype, type_cast, is_type_id=False, as_init_arg=False, default=-1, inplace='',
                 is_prim_init=False, arg_handler=''):
        # Identity of the argument: its name, declared dtype and applicable casts.
        self.arg_name, self.arg_dtype, self.type_cast = arg_name, arg_dtype, type_cast
        # Boolean markers for the type-id and init-argument roles.
        self.is_type_id, self.as_init_arg = is_type_id, as_init_arg
        # Default value and the name of the inplace tensor ('' when not inplace).
        self.default, self.inplace = default, inplace
        # Primitive-init marker and optional handler name ('' when unset).
        self.is_prim_init, self.arg_handler = is_prim_init, arg_handler
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class OpArgsSignature:
    """
    Argument signature of an operator. Every field defaults to None.

    Attributes:
        rw_write (list): Arguments that are written to.
        rw_read (list): Arguments that are read from.
        rw_ref (list): Arguments that are passed by reference.
        dtype_group (list): Grouping of data types for the arguments.
    """

    def __init__(self, rw_write=None, rw_read=None, rw_ref=None, dtype_group=None):
        # Read/write access classification of arguments.
        self.rw_write, self.rw_read, self.rw_ref = rw_write, rw_read, rw_ref
        # Data-type grouping.
        self.dtype_group = dtype_group
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class OpFunction:
    """
    Function settings of an operator.

    Attributes:
        disable (bool): Whether the function is disabled (False by default).
        name (str): Name of the function ('' by default).
    """

    def __init__(self, disable=False, name=''):
        self.disable, self.name = disable, name
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class OpClass:
    """
    Class settings of an operator.

    Attributes:
        disable (bool): Whether the class is disabled (False by default).
        name (str): Name of the class ('' by default).
    """

    def __init__(self, disable=False, name=''):
        self.disable, self.name = disable, name
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class OpDispatch:
    """
    Dispatch settings of an operator.

    Attributes:
        enable (bool): Whether dispatch is enabled.
        is_comm_op (bool): Whether the operator is a communication operator.
        ascend (str): Dispatch type for the Ascend device ('default' unless given).
        cpu (str): Dispatch type for the CPU ('default' unless given).
        gpu (str): Dispatch type for the GPU ('default' unless given).
    """

    def __init__(self, enable=False, is_comm_op=False, ascend='default', cpu='default', gpu='default'):
        # Global switches.
        self.enable, self.is_comm_op = enable, is_comm_op
        # Per-device dispatch types.
        self.ascend, self.cpu, self.gpu = ascend, cpu, gpu
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class OpProto:
    """
    Defines a prototype for an operator in MindSpore.

    This class is used to parse the operator definition from a YAML file and to generate
    the necessary primitive and PyBoost functions.

    Attributes:
        op_name (str): The name of the operator.
        op_args (list): OpArg instances describing the operator inputs.
        op_function (OpFunction): The function associated with the operator.
        op_class (OpClass): The class associated with the operator.
        op_dispatch (OpDispatch): The dispatch information for the operator, or None when
            the YAML has no (or an empty) 'dispatch' section.
        op_args_signature (OpArgsSignature): The signature of the operator's arguments, or
            None when the YAML has no 'args_signature' section.
        op_returns (list): OpArg instances describing the operator return values.
        op_view (bool): Indicates if the operator is a view operator (YAML key 'view').
        op_graph_view (bool): Value of the YAML key 'graph_view'.
        op_inplace (bool): True when any return value declares an 'inplace' target.
        op_labels: Value of the YAML key 'labels', or None.
        op_deprecated: Value of the YAML key 'deprecated', or None.
        bprop_expander (bool): Value of the YAML key 'bprop_expander' (True by default).
        func_op (bool): Func-op flag; False by default. The prototype loaders overwrite it
            after loading.
    """

    def __init__(self,
                 op_name,
                 op_args,
                 op_function,
                 op_class,
                 op_dispatch,
                 op_args_signature,
                 op_returns,
                 op_view=False,
                 op_graph_view=False,
                 op_inplace=False,
                 op_labels=None,
                 op_deprecated=None,
                 bprop_expander=True,
                 func_op=False):
        self.op_name = op_name
        self.op_args = op_args
        self.op_function = op_function
        self.op_class = op_class
        self.op_dispatch = op_dispatch
        self.op_args_signature = op_args_signature
        self.op_returns = op_returns
        self.op_view = op_view
        self.op_graph_view = op_graph_view
        self.op_inplace = op_inplace
        self.op_labels = op_labels
        self.op_deprecated = op_deprecated
        self.bprop_expander = bprop_expander
        # Declared here so every instance has the attribute, even when it was not
        # produced by a loader that assigns it afterwards.
        self.func_op = func_op

    @staticmethod
    def load_from_yaml(op_name, op_data):
        """
        Loads an operator prototype from YAML data.

        Args:
            op_name (str): The name of the operation.
            op_data (dict): A dictionary containing the operation data.

        Returns:
            OpProto: An instance of OpProto representing the operator.

        Raises:
            TypeError: If the definition has unknown or missing keys, or if 'view',
                'graph_view' or 'bprop_expander' is not a bool.
        """
        # check op keys
        check_validation(op_name, op_data)
        # get op args
        op_args = get_op_args(op_name, op_data)
        # get op return args
        op_returns = get_op_returns(op_name, op_data)
        # get op args signature
        op_args_signature = get_op_args_signature(op_name, op_data)
        # get op class
        op_class = get_op_class(op_name, op_data)
        # get op function
        op_function = get_op_function(op_name, op_data)
        # get op dispatch
        op_dispatch = get_op_dispatch(op_name, op_data)
        # get op view
        op_view = op_data.get('view', False)
        if not isinstance(op_view, bool):
            raise TypeError(
                f'The view value should be bool, but get {type(op_view)}, op name is {op_name}.')
        # get op graph view
        op_graph_view = op_data.get('graph_view', False)
        if not isinstance(op_graph_view, bool):
            raise TypeError(
                f'The graph view value should be bool, but get {type(op_graph_view)}, op name is {op_name}.')
        op_inplace = is_inplace_op(op_returns)
        # get op labels
        op_labels = op_data.get('labels', None)
        # get op deprecated
        op_deprecated = op_data.get('deprecated', None)
        # Validated like 'view'/'graph_view': a quoted YAML value such as 'False' is truthy.
        bprop_expander = op_data.get('bprop_expander', True)
        if not isinstance(bprop_expander, bool):
            raise TypeError(
                f'The bprop_expander value should be bool, but get {type(bprop_expander)}, op name is {op_name}.')
        op_proto = OpProto(op_name=op_name, op_args=op_args, op_returns=op_returns, op_function=op_function,
                           op_class=op_class, op_dispatch=op_dispatch, op_args_signature=op_args_signature,
                           op_view=op_view, op_graph_view=op_graph_view, op_inplace=op_inplace, op_labels=op_labels,
                           op_deprecated=op_deprecated, bprop_expander=bprop_expander)
        return op_proto
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class OpProtoLoader(ResourceLoader):
    """
    Loads operator prototypes from the YAML files in the op definition directory
    and its 'infer' sub-directory.

    Subclasses reuse ``load`` and only change the attributes set in ``__init__``.
    """
    def __init__(self):
        op_def_dir = os.path.join(K.WORK_DIR, K.MS_OP_DEF_YAML_PATH)
        self.yaml_paths = [op_def_dir, os.path.join(op_def_dir, 'infer')]
        self.type = ResourceType.OP_PROTO
        # When True, loaded prototypes get a 'deprecated_' prefix on their op_name.
        self.is_deprecated = False
        # Copied onto every loaded prototype as its func_op flag.
        self.func_op = False

    def load(self) -> Dict[ResourceType, object]:
        """
        Load OpProto.

        Returns:
            Dict[ResourceType, object]: Maps ``self.type`` to the list of loaded OpProto objects.
        """
        merged_defs = {}
        for path in self.yaml_paths:
            # dict.update: a later path wins when the same op name appears twice.
            merged_defs.update(safe_load_yaml_from_dir(path))

        protos = []
        for name, definition in merged_defs.items():
            proto = OpProto.load_from_yaml(name, definition)
            if self.is_deprecated:
                proto.op_name = 'deprecated_' + name
            proto.func_op = self.func_op
            protos.append(proto)
        return {self.type: protos}
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class DeprecatedOpProtoLoader(OpProtoLoader):
    """
    Loads deprecated operator prototypes from the deprecated op definition directory.

    Every loaded prototype is renamed with a 'deprecated_' prefix and gets func_op=True.
    """
    def __init__(self):
        super().__init__()
        # Replace all settings inherited from OpProtoLoader.
        self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEPRECATED_DEF_YAML_PATH)]
        self.type = ResourceType.DEPRECATED_OP_PROTO
        self.is_deprecated, self.func_op = True, True
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class FuncOpProtoLoader(OpProtoLoader):
    """
    Loads func_op operator prototypes from the func_op definition directory.

    Prototypes keep their original names and get func_op=True.
    """
    def __init__(self):
        super().__init__()
        # Replace all settings inherited from OpProtoLoader.
        self.yaml_paths = [os.path.join(K.WORK_DIR, K.MS_OP_DEF_FUNC_OP_YAML_PATH)]
        self.type = ResourceType.FUNC_OP_PROTO
        self.is_deprecated, self.func_op = False, True
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def get_op_args_signature(op_name, op_data):
    """
    Build the argument signature of an operator from its 'args_signature' section.

    Args:
        op_name (str): The name of the operation (used in key-validation errors).
        op_data (dict): A dictionary containing the operation data.

    Returns:
        OpArgsSignature: The parsed signature, or None when 'args_signature' is absent or None.

    Raises:
        TypeError: If the section contains keys outside K.ARG_SIGNATURE_KEYS.
    """
    signature = op_data.get('args_signature', None)
    if signature is None:
        return None
    check_op_yaml_keys(op_name, set(signature.keys()), K.ARG_SIGNATURE_KEYS)
    # Missing entries default to None.
    rw_write, rw_read, rw_ref, dtype_group = (
        signature.get(field, None) for field in ('rw_write', 'rw_read', 'rw_ref', 'dtype_group'))
    return OpArgsSignature(rw_write, rw_read, rw_ref, dtype_group)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def check_validation(op_name: str, op_data: dict):
    """
    Validate the top-level keys of an operator definition.

    Args:
        op_name (str): The operator name (used in error messages).
        op_data (dict): The operator data to validate.

    Raises:
        TypeError: If an unknown top-level key is present, or if the required
            'args' or 'returns' key is missing.
    """
    # Reject keys outside the allowed set first.
    check_op_yaml_keys(op_name, set(op_data.keys()), K.OP_KEYS)

    # These sections are mandatory; 'args' is checked before 'returns'.
    for required_key in ('args', 'returns'):
        if required_key not in op_data:
            raise TypeError(f"Op define miss key '{required_key}', op name is {op_name}")
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def get_op_args(op_name, op_data):
    """
    Retrieves the arguments for the operator from the operation data.

    Args:
        op_name (str): The name of the operation (used in error messages).
        op_data (dict): A dictionary containing the operation data.

    Returns:
        list: A list of OpArg instances representing the arguments of the operator.

    Raises:
        TypeError: If an argument definition contains keys outside K.ARG_KEYS or lacks 'dtype'.
    """
    args_dict = op_data.get('args')
    op_args = []
    for arg_name, arg_def in args_dict.items():
        check_op_yaml_keys(op_name, set(arg_def.keys()), K.ARG_KEYS)
        if 'dtype' not in arg_def:
            raise TypeError(f"op arg '{arg_name}' need key 'dtype', op name is {op_name}")
        arg_dtype = arg_def['dtype']
        # A TypeId argument is represented with the 'int' dtype.
        if arg_dtype == 'TypeId':
            arg_dtype = 'int'
        # An argument that declares a default becomes an init argument.
        as_init_arg = 'default' in arg_def
        default = arg_def.get('default')
        # If any argument of an op has prim_init, the op must be generated in pyboost_inner_prim.py.
        prim_init = arg_def.get('prim_init') is True
        type_cast = []
        if 'type_cast' in arg_def:
            # 'type_cast' is a comma-separated string, e.g. "tuple, list".
            type_cast = [cast_type.strip() for cast_type in arg_def['type_cast'].split(',')]
        arg_handler = arg_def.get('arg_handler', '')
        is_type_id = arg_handler == 'dtype_to_type_id'
        op_arg = OpArg(arg_name, arg_dtype, type_cast, is_type_id, as_init_arg, default,
                       is_prim_init=prim_init, arg_handler=arg_handler)
        op_args.append(op_arg)
    return op_args
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def get_op_returns(op_name, op_data):
    """
    Retrieves the return values for the operator from the operation data.

    Args:
        op_name (str): The name of the operation (used in error messages).
        op_data (dict): A dictionary containing the operation data; must have a 'returns' key.

    Returns:
        list: A list of OpArg instances representing the return values of the operator.

    Raises:
        TypeError: If a return definition contains keys outside K.RETURN_KEYS or lacks 'dtype'.
    """
    op_return_args = []
    for return_name, return_def in op_data['returns'].items():
        check_op_yaml_keys(op_name, set(return_def.keys()), K.RETURN_KEYS)
        # '' means the return value does not alias any input.
        inplace = return_def.get('inplace', '')
        if 'dtype' not in return_def:
            raise TypeError(
                f"op return args need key 'dtype', return name is {return_name}, op name is {op_name}")
        arg = OpArg(return_name, return_def['dtype'], type_cast=[], inplace=inplace)
        op_return_args.append(arg)
    return op_return_args
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def get_op_dispatch(op_name, op_data):
    """
    Retrieves the dispatch information for the operator from the operation data.

    Args:
        op_name (str): The name of the operation (used in error messages).
        op_data (dict): A dictionary containing the operation data.

    Returns:
        OpDispatch: The dispatch information, or None when the 'dispatch' section is absent or empty.

    Raises:
        TypeError: If the section contains keys outside K.DISPATCH_KEYS, or if 'enable' or
            'is_comm_op' is not a bool.
    """
    op_dispatch = op_data.get('dispatch', {})
    dispatch_keys = op_dispatch.keys()
    check_op_yaml_keys(op_name, set(dispatch_keys), K.DISPATCH_KEYS)
    if not op_dispatch:
        return None
    enable = op_dispatch.get('enable', False)
    if not isinstance(enable, bool):
        raise TypeError(
            f'The dispatch enable value should be bool, but get {type(enable)}, op name is {op_name}.')
    is_comm_op = op_dispatch.get('is_comm_op', False)
    # Validated like 'enable': a quoted YAML value such as 'False' is truthy and would be misread.
    if not isinstance(is_comm_op, bool):
        raise TypeError(
            f'The dispatch is_comm_op value should be bool, but get {type(is_comm_op)}, op name is {op_name}.')
    # Per-device dispatch types; 'default' when a device is not listed.
    ascend = op_dispatch.get('Ascend', 'default')
    cpu = op_dispatch.get('CPU', 'default')
    gpu = op_dispatch.get('GPU', 'default')
    return OpDispatch(enable, is_comm_op, ascend, cpu, gpu)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def get_op_class(op_name, op_data) -> OpClass:
    """
    Build the class settings of an operator from its optional 'class' section.

    Args:
        op_name (str): The name of the operation; also the source of the default class name.
        op_data (dict): A dictionary containing the operation data.

    Returns:
        OpClass: The class settings. The name defaults to the CamelCase form of ``op_name``.

    Raises:
        TypeError: If the section has keys outside K.CLASS_KEYS or 'disable' is not a bool.
    """
    class_section = op_data.get('class', {})
    check_op_yaml_keys(op_name, set(class_section.keys()), K.CLASS_KEYS)
    disabled = class_section.get('disable', False)
    if isinstance(disabled, bool):
        class_name = class_section.get('name', convert_python_func_name_to_c(op_name))
        return OpClass(disable=disabled, name=class_name)
    raise TypeError(
        f'The class disable value should be bool, but get {type(disabled)}, op name is {op_name}.')
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def get_op_function(op_name, op_data) -> OpFunction:
    """
    Build the function settings of an operator from its optional 'function' section.

    Args:
        op_name (str): The name of the operation; also the default function name.
        op_data (dict): A dictionary containing the operation data.

    Returns:
        OpFunction: The function settings.

    Raises:
        TypeError: If the section has keys outside K.FUNCTION_KEYS or 'disable' is not a bool.
    """
    function_section = op_data.get('function', {})
    check_op_yaml_keys(op_name, set(function_section.keys()), K.FUNCTION_KEYS)
    disabled = function_section.get('disable', False)
    if isinstance(disabled, bool):
        return OpFunction(disable=disabled, name=function_section.get('name', op_name))
    raise TypeError(
        f'The function disable value should be bool, but get {type(disabled)}, op name is {op_name}.')
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def convert_python_func_name_to_c(func_name: str) -> str:
    """
    Convert a snake_case name into a CamelCase class name.

    Each '_'-separated part goes through ``str.capitalize``, which also lowercases
    the rest of the part (e.g. 'conv2D_grad' -> 'Conv2dGrad').
    """
    return ''.join(map(str.capitalize, func_name.split('_')))
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def check_op_yaml_keys(op_name: str, input_keys: set, compare_keys: set):
    """
    Ensure that a YAML section of an op definition uses only allowed keys.

    Args:
        op_name (str): The operator name (used in the error message).
        input_keys (set): Keys found in the YAML section.
        compare_keys (set): Keys permitted for that section.

    Raises:
        TypeError: If ``input_keys`` contains any key not in ``compare_keys``.
    """
    diff_keys = input_keys - compare_keys
    if diff_keys:
        # Sort so the message is stable across runs: iteration order of a set of strings
        # depends on hash randomization. key=str tolerates non-string YAML keys.
        wrong_keys = sorted(diff_keys, key=str)
        raise TypeError(
            f'The definition of keys in yaml has faults, op name is {op_name}, wrong keys are {wrong_keys}.')
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def is_inplace_op(args):
    """
    Tell whether an operator updates a tensor in place.

    Args:
        args: Iterable of OpArg return values.

    Returns:
        bool: True if any element has a truthy ``inplace`` attribute, otherwise False.
    """
    return any(arg.inplace for arg in args)
|