mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic; see the package's registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -20,14 +20,133 @@ import mindspore as ms
|
|
|
20
20
|
from mindspore import ops
|
|
21
21
|
from mindspore.common.tensor import Tensor
|
|
22
22
|
from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
|
|
23
|
-
from mindspore.ops_generate.gen_ops_inner_prim import TupleToList, ListToTuple
|
|
24
23
|
from mindspore._c_expression import OpDtype
|
|
24
|
+
from mindspore._c_expression import typing
|
|
25
|
+
from mindspore._c_expression import op_enum
|
|
26
|
+
from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register
|
|
25
27
|
|
|
26
28
|
# Module-level instances of the tensor conversion primitives imported from `_sequence_ops`.
tensor_to_tuple_, tensor_to_scalar_ = TensorToTuple(), TensorToScalar()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TupleToList(Primitive):
    r"""
    Convert tuple to list.

    Inputs:
        - **input** (tuple) - The input tuple.

    Outputs:
        List, has the same elements as the `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> from mindspore.ops._utils.arg_dtype_cast import TupleToList
        >>> x = (1, 2, 3)
        >>> result = TupleToList()(x)
        >>> print(result)
        [1, 2, 3]
    """
    @prim_arg_register
    def __init__(self):
        """Initialize TupleToList"""

    def __call__(self, input):
        """Return a new list holding the elements of `input` in the same order."""
        # NOTE: `input` shadows the builtin, but it is the public parameter name and must stay.
        return list(input)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ListToTuple(Primitive):
    r"""
    Convert list to tuple.

    Inputs:
        - **input** (list) - The input list.

    Outputs:
        Tuple, has the same elements as the `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> from mindspore.ops._utils.arg_dtype_cast import ListToTuple
        >>> x = [1, 2, 3]
        >>> result = ListToTuple()(x)
        >>> print(result)
        (1, 2, 3)
    """
    @prim_arg_register
    def __init__(self):
        """Initialize ListToTuple"""

    def __call__(self, input):
        """Return a new tuple holding the elements of `input` in the same order."""
        # NOTE: `input` shadows the builtin, but it is the public parameter name and must stay.
        return tuple(input)
|
|
86
|
+
|
|
87
|
+
|
|
27
88
|
tuple_to_list = TupleToList()
|
|
28
89
|
list_to_tuple = ListToTuple()
|
|
29
90
|
|
|
30
91
|
|
|
92
|
+
class DtypeToEnum(Primitive):
|
|
93
|
+
r"""
|
|
94
|
+
Convert mindspore dtype to enum.
|
|
95
|
+
|
|
96
|
+
Inputs:
|
|
97
|
+
- **op_name** (str) - The op name
|
|
98
|
+
- **arg_name** (str) - The arg name
|
|
99
|
+
- **dtype** (mindspore.dtype) - The data type.
|
|
100
|
+
|
|
101
|
+
Outputs:
|
|
102
|
+
An integer.
|
|
103
|
+
|
|
104
|
+
Supported Platforms:
|
|
105
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
@prim_attr_register
|
|
109
|
+
def __init__(self):
|
|
110
|
+
"""Initialize"""
|
|
111
|
+
|
|
112
|
+
def __call__(self, op_name, arg_name, dtype):
|
|
113
|
+
"""Run in PyNative mode"""
|
|
114
|
+
if not isinstance(dtype, typing.Type):
|
|
115
|
+
raise TypeError(
|
|
116
|
+
f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
|
|
117
|
+
return typing.type_to_type_id(dtype)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class StringToEnum(Primitive):
|
|
121
|
+
r"""
|
|
122
|
+
Convert string to enum.
|
|
123
|
+
|
|
124
|
+
Inputs:
|
|
125
|
+
- **op_name** (str) - The op name
|
|
126
|
+
- **arg_name** (str) - The arg name
|
|
127
|
+
- **enum_str** (str) - The str data.
|
|
128
|
+
|
|
129
|
+
Outputs:
|
|
130
|
+
An integer.
|
|
131
|
+
|
|
132
|
+
Supported Platforms:
|
|
133
|
+
``CPU``
|
|
134
|
+
"""
|
|
135
|
+
|
|
136
|
+
@prim_attr_register
|
|
137
|
+
def __init__(self):
|
|
138
|
+
"""Initialize"""
|
|
139
|
+
|
|
140
|
+
def __call__(self, op_name, arg_name, enum_str):
|
|
141
|
+
"""Run in PyNative mode"""
|
|
142
|
+
if enum_str is None:
|
|
143
|
+
return None
|
|
144
|
+
if not isinstance(enum_str, str):
|
|
145
|
+
raise TypeError(
|
|
146
|
+
f"For '{op_name}', the input '{arg_name}' should be a str, but got {type(enum_str)}.")
|
|
147
|
+
return op_enum.str_to_enum(op_name, arg_name, enum_str)
|
|
148
|
+
|
|
149
|
+
|
|
31
150
|
def int_to_float(data):
|
|
32
151
|
return float(data)
|
|
33
152
|
|
|
@@ -184,7 +303,7 @@ def get_support_dtype_list(src_type, dst_type):
|
|
|
184
303
|
return support_list
|
|
185
304
|
|
|
186
305
|
|
|
187
|
-
def
|
|
306
|
+
def tensor_to_number(data, dst_type):
|
|
188
307
|
"""Convert tensor to python number"""
|
|
189
308
|
if dst_type == DT_INT_VAL:
|
|
190
309
|
data = ops.cast(data, ms.int64)
|
|
@@ -197,7 +316,7 @@ def to_py_number(data, dst_type):
|
|
|
197
316
|
data = ops.cast(data, ms.int64)
|
|
198
317
|
elif src_type in (ms.bfloat16, ms.float16, ms.float32, ms.float64):
|
|
199
318
|
data = ops.cast(data, ms.float32)
|
|
200
|
-
return
|
|
319
|
+
return tensor_to_scalar_(data)
|
|
201
320
|
|
|
202
321
|
|
|
203
322
|
def do_type_cast(data, dst_type):
|
|
@@ -230,7 +349,7 @@ def do_type_cast(data, dst_type):
|
|
|
230
349
|
return list_to_tensor(data)
|
|
231
350
|
elif is_number(dst_type):
|
|
232
351
|
if isinstance(data, Tensor):
|
|
233
|
-
return
|
|
352
|
+
return tensor_to_number(data, dst_type)
|
|
234
353
|
raise TypeError("Type conversion failed.")
|
|
235
354
|
|
|
236
355
|
|
|
@@ -14,13 +14,11 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Operator argument handle function."""
|
|
16
16
|
|
|
17
|
-
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
|
|
18
|
-
# Enum Class:
|
|
19
|
-
from mindspore._c_expression import FormatEnum as Format
|
|
20
|
-
from mindspore._c_expression import ReductionEnum as Reduction
|
|
21
17
|
from mindspore.common import Tensor
|
|
22
18
|
from mindspore.common import dtype as mstype
|
|
23
19
|
|
|
20
|
+
from .arg_dtype_cast import DtypeToEnum, StringToEnum
|
|
21
|
+
|
|
24
22
|
|
|
25
23
|
def arg_invalid_info(op_name, arg_name, arg_val):
|
|
26
24
|
"""
|
|
@@ -116,67 +114,6 @@ def to_2d_paddings(op_name, arg_name, pad):
|
|
|
116
114
|
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
117
115
|
|
|
118
116
|
|
|
119
|
-
def to_paddings(op_name, arg_name, pad):
|
|
120
|
-
"""
|
|
121
|
-
convert paddings: int -> tuple[int*4].
|
|
122
|
-
"""
|
|
123
|
-
if isinstance(pad, int):
|
|
124
|
-
return (pad,) * 4
|
|
125
|
-
if isinstance(pad, (tuple, list)):
|
|
126
|
-
return pad
|
|
127
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def to_3d_kernel_size(op_name, arg_name, kernel_size):
|
|
131
|
-
"""
|
|
132
|
-
convert 3d kernel_size: int/tuple[int*6] -> tuple[int*3].
|
|
133
|
-
"""
|
|
134
|
-
if isinstance(kernel_size, int):
|
|
135
|
-
return (kernel_size, kernel_size, kernel_size)
|
|
136
|
-
if isinstance(kernel_size, (tuple, list)):
|
|
137
|
-
if len(kernel_size) == 5:
|
|
138
|
-
return (kernel_size[2], kernel_size[3], kernel_size[4])
|
|
139
|
-
return kernel_size
|
|
140
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def to_3d_strides(op_name, arg_name, stride):
|
|
144
|
-
"""
|
|
145
|
-
convert 3d stride: int/tuple[int*6] -> tuple[int*3].
|
|
146
|
-
"""
|
|
147
|
-
if isinstance(stride, int):
|
|
148
|
-
return (stride, stride, stride)
|
|
149
|
-
if isinstance(stride, (tuple, list)):
|
|
150
|
-
if len(stride) == 5:
|
|
151
|
-
return (stride[2], stride[3], stride[4])
|
|
152
|
-
return stride
|
|
153
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, stride))
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def to_3d_dilations(op_name, arg_name, dilation):
|
|
157
|
-
"""
|
|
158
|
-
convert 3d dilation: int/tuple[int*6] -> tuple[int*3].
|
|
159
|
-
"""
|
|
160
|
-
if isinstance(dilation, int):
|
|
161
|
-
return (dilation, dilation, dilation)
|
|
162
|
-
if isinstance(dilation, (tuple, list)):
|
|
163
|
-
if len(dilation) == 5:
|
|
164
|
-
return (dilation[2], dilation[3], dilation[4])
|
|
165
|
-
return dilation
|
|
166
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def to_3d_paddings(op_name, arg_name, pad):
|
|
170
|
-
"""
|
|
171
|
-
convert 3d paddings: int -> tuple[int*6].
|
|
172
|
-
"""
|
|
173
|
-
if isinstance(pad, int):
|
|
174
|
-
return (pad,) * 6
|
|
175
|
-
if isinstance(pad, (tuple, list)):
|
|
176
|
-
return pad
|
|
177
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
178
|
-
|
|
179
|
-
|
|
180
117
|
def generator_handler(op_name, arg_name, inputs):
|
|
181
118
|
"""
|
|
182
119
|
convert constant value in tuple to tensor
|
|
@@ -189,6 +126,7 @@ def generator_handler(op_name, arg_name, inputs):
|
|
|
189
126
|
new_inputs.append(input_)
|
|
190
127
|
return tuple(new_inputs)
|
|
191
128
|
|
|
129
|
+
|
|
192
130
|
dtype_to_type_id = DtypeToEnum()
|
|
193
131
|
|
|
194
132
|
# string to enum
|
|
@@ -15,12 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
"""array_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
|
+
from enum import Enum
|
|
18
19
|
|
|
19
20
|
import mindspore
|
|
20
21
|
import mindspore.numpy as mnp
|
|
21
22
|
from mindspore import ops
|
|
22
23
|
from mindspore.common import Tensor
|
|
23
|
-
from mindspore._c_expression import
|
|
24
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
24
25
|
from mindspore.ops import operations as P
|
|
25
26
|
from mindspore.ops import functional as F
|
|
26
27
|
from mindspore.ops.primitive import constexpr, _primexpr
|
|
@@ -140,6 +141,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
140
141
|
the generated prefix is a Tensor([[[0], [0]],
|
|
141
142
|
[[1], [1]]])
|
|
142
143
|
"""
|
|
144
|
+
cast_op = P.Cast()
|
|
145
|
+
|
|
143
146
|
def _check(indices_shape):
|
|
144
147
|
if not indices_shape:
|
|
145
148
|
raise ValueError("indices_shape is empty in _get_prefix.")
|
|
@@ -147,8 +150,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
147
150
|
_check(indices_shape)
|
|
148
151
|
indices_len = len(indices_shape)
|
|
149
152
|
if indices_len == 1:
|
|
150
|
-
prefix = P.Range()(
|
|
151
|
-
return prefix
|
|
153
|
+
prefix = P.Range()(0, axis_size, 1)
|
|
154
|
+
return cast_op(prefix, indices_dtype)
|
|
152
155
|
|
|
153
156
|
indices_end = indices_len - 1
|
|
154
157
|
prefix_shape = ()
|
|
@@ -163,9 +166,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
163
166
|
else:
|
|
164
167
|
expand_shape = expand_shape + (1,)
|
|
165
168
|
|
|
166
|
-
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(
|
|
167
|
-
|
|
168
|
-
return prefix
|
|
169
|
+
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(0, axis_size, 1), expand_shape))
|
|
170
|
+
return cast_op(prefix, indices_dtype)
|
|
169
171
|
|
|
170
172
|
|
|
171
173
|
@vmap_rules_getters.register(P.Transpose)
|
|
@@ -212,6 +214,31 @@ def get_transpose_vmap_rule(prim, axis_size):
|
|
|
212
214
|
return vmap_rule
|
|
213
215
|
|
|
214
216
|
|
|
217
|
+
@vmap_rules_getters.register("TransposeExtView")
|
|
218
|
+
def get_transpose_ext_vmap_rule(prim, axis_size):
|
|
219
|
+
"""VmapRule for `TransposeExtView` operation."""
|
|
220
|
+
if isinstance(prim, str):
|
|
221
|
+
prim = Primitive(prim)
|
|
222
|
+
|
|
223
|
+
def vmap_rule(x_bdim, dim1_bdim, dim2_bdim):
|
|
224
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim1_bdim, dim2_bdim)
|
|
225
|
+
if is_all_none:
|
|
226
|
+
return result
|
|
227
|
+
|
|
228
|
+
x, dim = x_bdim
|
|
229
|
+
dim1, dim1_dim = dim1_bdim
|
|
230
|
+
dim2, dim2_dim = dim2_bdim
|
|
231
|
+
if dim1_dim is not None or dim2_dim is not None:
|
|
232
|
+
_raise_value_error("The source axis of dim1_dim and dim2_dim in `TransposeExtView` must be None, "
|
|
233
|
+
"but got {} and {}.".format(dim1_dim, dim2_dim))
|
|
234
|
+
batch_dim1 = dim1 if dim1 < dim else dim1 + 1
|
|
235
|
+
batch_dim2 = dim2 if dim2 < dim else dim2 + 1
|
|
236
|
+
out = prim(x, batch_dim1, batch_dim2)
|
|
237
|
+
return out, dim
|
|
238
|
+
|
|
239
|
+
return vmap_rule
|
|
240
|
+
|
|
241
|
+
|
|
215
242
|
@vmap_rules_getters.register("Tile")
|
|
216
243
|
def get_tile_vmap_rule(prim, axis_size):
|
|
217
244
|
"""VmapRule for `P.Tile` operation."""
|
|
@@ -1488,16 +1515,19 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1488
1515
|
"""VmapRule for `P.Meshgrid` operation."""
|
|
1489
1516
|
if isinstance(prim, str):
|
|
1490
1517
|
prim = Primitive(prim)
|
|
1491
|
-
indexing = prim.indexing
|
|
1492
1518
|
|
|
1493
|
-
|
|
1494
|
-
|
|
1519
|
+
class Indexing(Enum):
|
|
1520
|
+
ij = 0
|
|
1521
|
+
xy = 1
|
|
1522
|
+
|
|
1523
|
+
def vmap_rule(inputs_bdim, indexing_bdim):
|
|
1524
|
+
is_all_none, result = vmap_general_preprocess(prim, inputs_bdim, indexing_bdim)
|
|
1495
1525
|
if is_all_none:
|
|
1496
1526
|
return result
|
|
1497
1527
|
|
|
1498
1528
|
if not isinstance(inputs_bdim, (tuple)):
|
|
1499
1529
|
_raise_value_error("The inputs of P.Meshgrid is not tuple.")
|
|
1500
|
-
args = inputs_bdim
|
|
1530
|
+
args = inputs_bdim
|
|
1501
1531
|
if len(args) <= 1:
|
|
1502
1532
|
_raise_value_error(
|
|
1503
1533
|
"The input number of P.Meshgrid must be greater than 1.")
|
|
@@ -1518,7 +1548,9 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1518
1548
|
output_shape.insert(0, axis_size)
|
|
1519
1549
|
ones_shape.insert(0, axis_size)
|
|
1520
1550
|
|
|
1521
|
-
|
|
1551
|
+
indexing, _ = indexing_bdim
|
|
1552
|
+
|
|
1553
|
+
if indexing == Indexing.xy.value:
|
|
1522
1554
|
output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
|
|
1523
1555
|
shape = tuple(output_shape)
|
|
1524
1556
|
|
|
@@ -1531,7 +1563,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1531
1563
|
for each_arg in args:
|
|
1532
1564
|
x, bdim = each_arg
|
|
1533
1565
|
x = _bdim_at_front(x, bdim, axis_size)
|
|
1534
|
-
shape_index = (1 - index) if (index <= 1 and indexing ==
|
|
1566
|
+
shape_index = (1 - index) if (index <= 1 and indexing == Indexing.xy.value) else index
|
|
1535
1567
|
ones_shape[shape_index + 1] = output_shape[shape_index + 1]
|
|
1536
1568
|
x = P.Reshape()(x, tuple(ones_shape))
|
|
1537
1569
|
output = P.Mul()(x, ones_tensor)
|
|
@@ -1889,10 +1921,6 @@ def get_slice_vmap_rule(prim, axis_size):
|
|
|
1889
1921
|
@vmap_rules_getters.register(P.Squeeze)
|
|
1890
1922
|
def get_squeeze_vmap_rule(prim, axis_size):
|
|
1891
1923
|
"""VmapRule for `Squeeze`."""
|
|
1892
|
-
if hasattr(prim, 'axis'):
|
|
1893
|
-
prim_axis = prim.axis
|
|
1894
|
-
else:
|
|
1895
|
-
prim_axis = None
|
|
1896
1924
|
|
|
1897
1925
|
@_primexpr
|
|
1898
1926
|
def move_axis(axes):
|
|
@@ -1911,27 +1939,26 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1911
1939
|
new_axis += (i,)
|
|
1912
1940
|
return new_axis
|
|
1913
1941
|
|
|
1914
|
-
def vmap_rule(x_bdim):
|
|
1915
|
-
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
1942
|
+
def vmap_rule(x_bdim, axis_bdim):
|
|
1943
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
|
|
1916
1944
|
if is_all_none:
|
|
1917
1945
|
return result
|
|
1918
1946
|
|
|
1919
1947
|
x, x_dim = x_bdim
|
|
1948
|
+
axis, _ = axis_bdim
|
|
1920
1949
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
1921
1950
|
|
|
1922
|
-
if
|
|
1951
|
+
if axis is None:
|
|
1923
1952
|
if axis_size == 1:
|
|
1924
1953
|
new_axis = generate_all_axis_except_first(F.rank(x))
|
|
1925
|
-
|
|
1926
|
-
out = batch_squeeze(x)
|
|
1954
|
+
out = prim(x, new_axis)
|
|
1927
1955
|
return out, 0
|
|
1928
1956
|
|
|
1929
|
-
out = prim(x)
|
|
1957
|
+
out = prim(x, axis)
|
|
1930
1958
|
return out, 0
|
|
1931
1959
|
|
|
1932
|
-
new_axis = move_axis(
|
|
1933
|
-
|
|
1934
|
-
out = batch_squeeze(x)
|
|
1960
|
+
new_axis = move_axis(axis)
|
|
1961
|
+
out = prim(x, new_axis)
|
|
1935
1962
|
return out, 0
|
|
1936
1963
|
|
|
1937
1964
|
return vmap_rule
|
mindspore/ops/_vmap/vmap_base.py
CHANGED
|
@@ -512,8 +512,6 @@ _ops_vmap_clone_prim_dict = {
|
|
|
512
512
|
"ApplyAdagradV2": P.ApplyAdagradV2,
|
|
513
513
|
"UniformCandidateSampler": UniformCandidateSampler,
|
|
514
514
|
"UniqueWithPad": P.UniqueWithPad,
|
|
515
|
-
"CdistGrad": G.CdistGrad,
|
|
516
|
-
"Cdist": P.Cdist,
|
|
517
515
|
"STFT": math_ops.STFT,
|
|
518
516
|
"Conv2D": P.Conv2D,
|
|
519
517
|
"Conv3D": P.Conv3D,
|
|
@@ -25,7 +25,9 @@ from mindspore.ops.primitive import _primexpr
|
|
|
25
25
|
from mindspore.ops.function import _VmapGeneralRule
|
|
26
26
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
|
|
27
27
|
_bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
|
|
28
|
-
from mindspore.ops
|
|
28
|
+
from mindspore.ops import auto_generate as gen
|
|
29
|
+
from mindspore._c_expression import FormatEnum as Format
|
|
30
|
+
from mindspore._c_expression import ReductionEnum as Reduction
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
@vmap_rules_getters.register(G.NLLLossGrad)
|
|
@@ -225,33 +227,35 @@ def get_max_pool3d_grad_with_argmax_vmap_rule(prim, axis_size):
|
|
|
225
227
|
return vmap_rule
|
|
226
228
|
|
|
227
229
|
|
|
228
|
-
@vmap_rules_getters.register(
|
|
230
|
+
@vmap_rules_getters.register(gen.CdistGrad)
|
|
229
231
|
def get_cdist_grad_vmap_rule(prim, axis_size):
|
|
230
232
|
"""VmapRule for `cdist grad` operation."""
|
|
231
|
-
if
|
|
232
|
-
batch_rank = prim.batch_rank + 1
|
|
233
|
+
if prim.has_label("batch_rank"):
|
|
234
|
+
batch_rank = prim.get_label("batch_rank") + 1
|
|
233
235
|
else:
|
|
234
236
|
batch_rank = 1
|
|
235
237
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
+
prim = prim.clone()
|
|
239
|
+
prim.set_label('batch_rank', batch_rank)
|
|
238
240
|
|
|
239
|
-
def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim):
|
|
240
|
-
is_all_none, result = vmap_general_preprocess(
|
|
241
|
-
|
|
241
|
+
def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim):
|
|
242
|
+
is_all_none, result = vmap_general_preprocess(
|
|
243
|
+
prim, grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim
|
|
244
|
+
)
|
|
242
245
|
if is_all_none:
|
|
243
246
|
return result
|
|
244
247
|
grad, grad_dim = grad_bdim
|
|
245
248
|
x, x_dim = x_bdim
|
|
246
249
|
y, y_dim = y_bdim
|
|
247
250
|
cdist, cdist_dim = cdist_bdim
|
|
251
|
+
p, _ = p_bdim
|
|
248
252
|
|
|
249
253
|
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
250
254
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
251
255
|
y = _bdim_at_front(y, y_dim, axis_size)
|
|
252
256
|
cdist = _bdim_at_front(cdist, cdist_dim, axis_size)
|
|
253
257
|
|
|
254
|
-
out =
|
|
258
|
+
out = prim(grad, x, y, cdist, p)
|
|
255
259
|
return out, 0
|
|
256
260
|
|
|
257
261
|
return vmap_rule
|
|
@@ -673,10 +677,11 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
673
677
|
else:
|
|
674
678
|
_raise_value_error("The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
|
|
675
679
|
|
|
676
|
-
|
|
677
|
-
|
|
680
|
+
def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
|
|
681
|
+
output_mask_bdim):
|
|
678
682
|
is_all_none, result = vmap_general_preprocess(
|
|
679
|
-
prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim
|
|
683
|
+
prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
|
|
684
|
+
output_mask_bdim)
|
|
680
685
|
if is_all_none:
|
|
681
686
|
return result
|
|
682
687
|
|
|
@@ -686,6 +691,7 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
686
691
|
interpolation_mode, _ = interpolation_mode_bdim
|
|
687
692
|
padding_mode, _ = padding_mode_bdim
|
|
688
693
|
align_corners, _ = align_corners_bdim
|
|
694
|
+
output_mask, _ = output_mask_bdim
|
|
689
695
|
|
|
690
696
|
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
691
697
|
grad_shape = F.shape(grad)
|
|
@@ -699,7 +705,8 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
699
705
|
grid_shape = F.shape(grid)
|
|
700
706
|
grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
|
|
701
707
|
|
|
702
|
-
dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
|
|
708
|
+
dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
|
|
709
|
+
padding_mode, align_corners, output_mask)
|
|
703
710
|
dx_shape = F.shape(dx)
|
|
704
711
|
dx_return_shape = input_x_shape[:non_batch_dim_index] + dx_shape[non_batch_dim_index:]
|
|
705
712
|
dx = F.reshape(dx, dx_return_shape)
|
|
@@ -19,6 +19,7 @@ from __future__ import absolute_import
|
|
|
19
19
|
import mindspore.numpy as mnp
|
|
20
20
|
from mindspore.ops import operations as P
|
|
21
21
|
from mindspore.ops import functional as F
|
|
22
|
+
from mindspore.ops import auto_generate as gen
|
|
22
23
|
from mindspore.ops.auto_generate import MatMulExt
|
|
23
24
|
from mindspore.ops.primitive import _primexpr
|
|
24
25
|
from mindspore.common import Tensor
|
|
@@ -29,7 +30,7 @@ from mindspore.ops.primitive import Primitive
|
|
|
29
30
|
from mindspore.ops.function import _VmapGeneralRule
|
|
30
31
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
|
|
31
32
|
get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
|
|
32
|
-
|
|
33
|
+
_bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
|
|
33
34
|
from mindspore.ops.operations.math_ops import Bernoulli, BesselI0, BesselI1, BesselJ0, BesselJ1, \
|
|
34
35
|
BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e, Median
|
|
35
36
|
|
|
@@ -128,28 +129,29 @@ def get_addcxxx_vmap_rule(prim, axis_size):
|
|
|
128
129
|
return vmap_rule
|
|
129
130
|
|
|
130
131
|
|
|
131
|
-
@vmap_rules_getters.register(
|
|
132
|
+
@vmap_rules_getters.register(gen.Cdist)
|
|
132
133
|
def get_cdist_vmap_rule(prim, axis_size):
|
|
133
134
|
"""VmapRule for `cdist` operation."""
|
|
134
|
-
if
|
|
135
|
-
batch_rank = prim.batch_rank + 1
|
|
135
|
+
if prim.has_label("batch_rank"):
|
|
136
|
+
batch_rank = prim.get_label("batch_rank") + 1
|
|
136
137
|
else:
|
|
137
138
|
batch_rank = 1
|
|
138
139
|
|
|
139
|
-
|
|
140
|
-
|
|
140
|
+
prim = prim.clone()
|
|
141
|
+
prim.set_label('batch_rank', batch_rank)
|
|
141
142
|
|
|
142
|
-
def vmap_rule(x_bdim, y_bdim):
|
|
143
|
+
def vmap_rule(x_bdim, y_bdim, p_bdim):
|
|
143
144
|
x, x_dim = x_bdim
|
|
144
145
|
y, y_dim = y_bdim
|
|
146
|
+
p, _ = p_bdim
|
|
145
147
|
|
|
146
|
-
if x_dim is None and y_dim is None:
|
|
148
|
+
if x_dim is None and y_dim is None and p is None:
|
|
147
149
|
out = prim(x, y)
|
|
148
150
|
return (out, None)
|
|
149
151
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
150
152
|
y = _bdim_at_front(y, y_dim, axis_size)
|
|
151
153
|
|
|
152
|
-
out =
|
|
154
|
+
out = prim(x, y, p)
|
|
153
155
|
return out, 0
|
|
154
156
|
|
|
155
157
|
return vmap_rule
|
|
@@ -559,20 +561,17 @@ def get_index_add_vmap_rule(prim, axis_size):
|
|
|
559
561
|
@vmap_rules_getters.register(linalg_ops.Svd)
|
|
560
562
|
def get_svd_vmap_rule(prim, axis_size):
|
|
561
563
|
"""VmapRule for 'Svd' operation."""
|
|
562
|
-
if isinstance(prim, str):
|
|
563
|
-
prim = Primitive(prim)
|
|
564
|
-
compute_uv = True
|
|
565
|
-
else:
|
|
566
|
-
compute_uv = prim.compute_uv
|
|
567
564
|
|
|
568
|
-
def vmap_rule(x_bdim):
|
|
565
|
+
def vmap_rule(x_bdim, full_matrices_bdim, compute_uv_bdim):
|
|
569
566
|
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
570
567
|
if is_all_none:
|
|
571
568
|
return result
|
|
572
569
|
|
|
573
570
|
x, x_dim = x_bdim
|
|
571
|
+
full_matrices, _ = full_matrices_bdim
|
|
572
|
+
compute_uv, _ = compute_uv_bdim
|
|
574
573
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
575
|
-
s, u, v = prim(x)
|
|
574
|
+
s, u, v = prim(x, full_matrices, compute_uv)
|
|
576
575
|
if compute_uv:
|
|
577
576
|
return (s, 0), (u, 0), (v, 0)
|
|
578
577
|
return (s, 0), (u, None), (v, None)
|