mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -20,14 +20,133 @@ import mindspore as ms
|
|
|
20
20
|
from mindspore import ops
|
|
21
21
|
from mindspore.common.tensor import Tensor
|
|
22
22
|
from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
|
|
23
|
-
from mindspore.ops_generate.gen_ops_inner_prim import TupleToList, ListToTuple
|
|
24
23
|
from mindspore._c_expression import OpDtype
|
|
24
|
+
from mindspore._c_expression import typing
|
|
25
|
+
from mindspore._c_expression import op_enum
|
|
26
|
+
from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register
|
|
25
27
|
|
|
26
28
|
# Module-level TensorToTuple primitive instance.
tensor_to_tuple_ = TensorToTuple()
|
|
29
|
+
# Module-level TensorToScalar primitive instance.
tensor_to_scalar_ = TensorToScalar()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TupleToList(Primitive):
    r"""
    Convert tuple to list.

    Inputs:
        - **input** (tuple) - The tuple to convert.

    Outputs:
        List, has the same elements as the `input`, in the same order.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> from mindspore.ops._utils.arg_dtype_cast import TupleToList
        >>> x = (1, 2, 3)
        >>> result = TupleToList()(x)
        >>> print(result)
        [1, 2, 3]
    """
    @prim_arg_register
    def __init__(self):
        """Initialize TupleToList."""

    def __call__(self, input):
        """Return a new list holding the elements of `input`, in order."""
        # `input` shadows the builtin, but the parameter name is kept
        # unchanged for interface compatibility with existing callers.
        return list(input)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ListToTuple(Primitive):
|
|
61
|
+
r"""
|
|
62
|
+
Convert list to tuple.
|
|
63
|
+
|
|
64
|
+
Inputs:
|
|
65
|
+
- **x** (list) - The input
|
|
66
|
+
|
|
67
|
+
Outputs:
|
|
68
|
+
Tuple, has the same elements as the `input`.
|
|
69
|
+
|
|
70
|
+
Supported Platforms:
|
|
71
|
+
``CPU``
|
|
72
|
+
|
|
73
|
+
Examples:
|
|
74
|
+
>>> from mindspore.ops._utils.arg_dtype_cast import ListToTuple
|
|
75
|
+
>>> x = [1, 2, 3]
|
|
76
|
+
>>> result = ListToTuple()(x)
|
|
77
|
+
>>> print(result)
|
|
78
|
+
(1, 2, 3)
|
|
79
|
+
"""
|
|
80
|
+
@prim_arg_register
|
|
81
|
+
def __init__(self):
|
|
82
|
+
"""Initialize TupleToList"""
|
|
83
|
+
|
|
84
|
+
def __call__(self, input):
|
|
85
|
+
return tuple(input)
|
|
86
|
+
|
|
87
|
+
|
|
27
88
|
tuple_to_list = TupleToList()
|
|
28
89
|
list_to_tuple = ListToTuple()
|
|
29
90
|
|
|
30
91
|
|
|
92
|
+
class DtypeToEnum(Primitive):
|
|
93
|
+
r"""
|
|
94
|
+
Convert mindspore dtype to enum.
|
|
95
|
+
|
|
96
|
+
Inputs:
|
|
97
|
+
- **op_name** (str) - The op name
|
|
98
|
+
- **arg_name** (str) - The arg name
|
|
99
|
+
- **dtype** (mindspore.dtype) - The data type.
|
|
100
|
+
|
|
101
|
+
Outputs:
|
|
102
|
+
An integer.
|
|
103
|
+
|
|
104
|
+
Supported Platforms:
|
|
105
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
@prim_attr_register
|
|
109
|
+
def __init__(self):
|
|
110
|
+
"""Initialize"""
|
|
111
|
+
|
|
112
|
+
def __call__(self, op_name, arg_name, dtype):
|
|
113
|
+
"""Run in PyNative mode"""
|
|
114
|
+
if not isinstance(dtype, typing.Type):
|
|
115
|
+
raise TypeError(
|
|
116
|
+
f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
|
|
117
|
+
return typing.type_to_type_id(dtype)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class StringToEnum(Primitive):
|
|
121
|
+
r"""
|
|
122
|
+
Convert string to enum.
|
|
123
|
+
|
|
124
|
+
Inputs:
|
|
125
|
+
- **op_name** (str) - The op name
|
|
126
|
+
- **arg_name** (str) - The arg name
|
|
127
|
+
- **enum_str** (str) - The str data.
|
|
128
|
+
|
|
129
|
+
Outputs:
|
|
130
|
+
An integer.
|
|
131
|
+
|
|
132
|
+
Supported Platforms:
|
|
133
|
+
``CPU``
|
|
134
|
+
"""
|
|
135
|
+
|
|
136
|
+
@prim_attr_register
|
|
137
|
+
def __init__(self):
|
|
138
|
+
"""Initialize"""
|
|
139
|
+
|
|
140
|
+
def __call__(self, op_name, arg_name, enum_str):
|
|
141
|
+
"""Run in PyNative mode"""
|
|
142
|
+
if enum_str is None:
|
|
143
|
+
return None
|
|
144
|
+
if not isinstance(enum_str, str):
|
|
145
|
+
raise TypeError(
|
|
146
|
+
f"For '{op_name}', the input '{arg_name}' should be a str, but got {type(enum_str)}.")
|
|
147
|
+
return op_enum.str_to_enum(op_name, arg_name, enum_str)
|
|
148
|
+
|
|
149
|
+
|
|
31
150
|
def int_to_float(data):
|
|
32
151
|
return float(data)
|
|
33
152
|
|
|
@@ -184,7 +303,7 @@ def get_support_dtype_list(src_type, dst_type):
|
|
|
184
303
|
return support_list
|
|
185
304
|
|
|
186
305
|
|
|
187
|
-
def
|
|
306
|
+
def tensor_to_number(data, dst_type):
|
|
188
307
|
"""Convert tensor to python number"""
|
|
189
308
|
if dst_type == DT_INT_VAL:
|
|
190
309
|
data = ops.cast(data, ms.int64)
|
|
@@ -197,7 +316,7 @@ def to_py_number(data, dst_type):
|
|
|
197
316
|
data = ops.cast(data, ms.int64)
|
|
198
317
|
elif src_type in (ms.bfloat16, ms.float16, ms.float32, ms.float64):
|
|
199
318
|
data = ops.cast(data, ms.float32)
|
|
200
|
-
return
|
|
319
|
+
return tensor_to_scalar_(data)
|
|
201
320
|
|
|
202
321
|
|
|
203
322
|
def do_type_cast(data, dst_type):
|
|
@@ -230,7 +349,7 @@ def do_type_cast(data, dst_type):
|
|
|
230
349
|
return list_to_tensor(data)
|
|
231
350
|
elif is_number(dst_type):
|
|
232
351
|
if isinstance(data, Tensor):
|
|
233
|
-
return
|
|
352
|
+
return tensor_to_number(data, dst_type)
|
|
234
353
|
raise TypeError("Type conversion failed.")
|
|
235
354
|
|
|
236
355
|
|
|
@@ -14,13 +14,11 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Operator argument handle function."""
|
|
16
16
|
|
|
17
|
-
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
|
|
18
|
-
# Enum Class:
|
|
19
|
-
from mindspore._c_expression import FormatEnum as Format
|
|
20
|
-
from mindspore._c_expression import ReductionEnum as Reduction
|
|
21
17
|
from mindspore.common import Tensor
|
|
22
18
|
from mindspore.common import dtype as mstype
|
|
23
19
|
|
|
20
|
+
from .arg_dtype_cast import DtypeToEnum, StringToEnum
|
|
21
|
+
|
|
24
22
|
|
|
25
23
|
def arg_invalid_info(op_name, arg_name, arg_val):
|
|
26
24
|
"""
|
|
@@ -116,67 +114,6 @@ def to_2d_paddings(op_name, arg_name, pad):
|
|
|
116
114
|
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
117
115
|
|
|
118
116
|
|
|
119
|
-
def to_paddings(op_name, arg_name, pad):
|
|
120
|
-
"""
|
|
121
|
-
convert paddings: int -> tuple[int*4].
|
|
122
|
-
"""
|
|
123
|
-
if isinstance(pad, int):
|
|
124
|
-
return (pad,) * 4
|
|
125
|
-
if isinstance(pad, (tuple, list)):
|
|
126
|
-
return pad
|
|
127
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
def to_3d_kernel_size(op_name, arg_name, kernel_size):
|
|
131
|
-
"""
|
|
132
|
-
convert 3d kernel_size: int/tuple[int*6] -> tuple[int*3].
|
|
133
|
-
"""
|
|
134
|
-
if isinstance(kernel_size, int):
|
|
135
|
-
return (kernel_size, kernel_size, kernel_size)
|
|
136
|
-
if isinstance(kernel_size, (tuple, list)):
|
|
137
|
-
if len(kernel_size) == 5:
|
|
138
|
-
return (kernel_size[2], kernel_size[3], kernel_size[4])
|
|
139
|
-
return kernel_size
|
|
140
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def to_3d_strides(op_name, arg_name, stride):
|
|
144
|
-
"""
|
|
145
|
-
convert 3d stride: int/tuple[int*6] -> tuple[int*3].
|
|
146
|
-
"""
|
|
147
|
-
if isinstance(stride, int):
|
|
148
|
-
return (stride, stride, stride)
|
|
149
|
-
if isinstance(stride, (tuple, list)):
|
|
150
|
-
if len(stride) == 5:
|
|
151
|
-
return (stride[2], stride[3], stride[4])
|
|
152
|
-
return stride
|
|
153
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, stride))
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def to_3d_dilations(op_name, arg_name, dilation):
|
|
157
|
-
"""
|
|
158
|
-
convert 3d dilation: int/tuple[int*6] -> tuple[int*3].
|
|
159
|
-
"""
|
|
160
|
-
if isinstance(dilation, int):
|
|
161
|
-
return (dilation, dilation, dilation)
|
|
162
|
-
if isinstance(dilation, (tuple, list)):
|
|
163
|
-
if len(dilation) == 5:
|
|
164
|
-
return (dilation[2], dilation[3], dilation[4])
|
|
165
|
-
return dilation
|
|
166
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def to_3d_paddings(op_name, arg_name, pad):
|
|
170
|
-
"""
|
|
171
|
-
convert 3d paddings: int -> tuple[int*6].
|
|
172
|
-
"""
|
|
173
|
-
if isinstance(pad, int):
|
|
174
|
-
return (pad,) * 6
|
|
175
|
-
if isinstance(pad, (tuple, list)):
|
|
176
|
-
return pad
|
|
177
|
-
raise ValueError(arg_invalid_info(op_name, arg_name, pad))
|
|
178
|
-
|
|
179
|
-
|
|
180
117
|
def generator_handler(op_name, arg_name, inputs):
|
|
181
118
|
"""
|
|
182
119
|
convert constant value in tuple to tensor
|
|
@@ -189,6 +126,7 @@ def generator_handler(op_name, arg_name, inputs):
|
|
|
189
126
|
new_inputs.append(input_)
|
|
190
127
|
return tuple(new_inputs)
|
|
191
128
|
|
|
129
|
+
|
|
192
130
|
dtype_to_type_id = DtypeToEnum()
|
|
193
131
|
|
|
194
132
|
# string to enum
|
|
@@ -15,12 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
"""array_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
|
+
from enum import Enum
|
|
18
19
|
|
|
19
20
|
import mindspore
|
|
20
21
|
import mindspore.numpy as mnp
|
|
21
22
|
from mindspore import ops
|
|
22
23
|
from mindspore.common import Tensor
|
|
23
|
-
from mindspore._c_expression import
|
|
24
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
24
25
|
from mindspore.ops import operations as P
|
|
25
26
|
from mindspore.ops import functional as F
|
|
26
27
|
from mindspore.ops.primitive import constexpr, _primexpr
|
|
@@ -140,6 +141,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
140
141
|
the generated prefix is a Tensor([[[0], [0]],
|
|
141
142
|
[[1], [1]]])
|
|
142
143
|
"""
|
|
144
|
+
cast_op = P.Cast()
|
|
145
|
+
|
|
143
146
|
def _check(indices_shape):
|
|
144
147
|
if not indices_shape:
|
|
145
148
|
raise ValueError("indices_shape is empty in _get_prefix.")
|
|
@@ -147,8 +150,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
147
150
|
_check(indices_shape)
|
|
148
151
|
indices_len = len(indices_shape)
|
|
149
152
|
if indices_len == 1:
|
|
150
|
-
prefix = P.Range()(
|
|
151
|
-
return prefix
|
|
153
|
+
prefix = P.Range()(0, axis_size, 1)
|
|
154
|
+
return cast_op(prefix, indices_dtype)
|
|
152
155
|
|
|
153
156
|
indices_end = indices_len - 1
|
|
154
157
|
prefix_shape = ()
|
|
@@ -163,9 +166,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
163
166
|
else:
|
|
164
167
|
expand_shape = expand_shape + (1,)
|
|
165
168
|
|
|
166
|
-
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(
|
|
167
|
-
|
|
168
|
-
return prefix
|
|
169
|
+
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(0, axis_size, 1), expand_shape))
|
|
170
|
+
return cast_op(prefix, indices_dtype)
|
|
169
171
|
|
|
170
172
|
|
|
171
173
|
@vmap_rules_getters.register(P.Transpose)
|
|
@@ -1488,16 +1490,19 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1488
1490
|
"""VmapRule for `P.Meshgrid` operation."""
|
|
1489
1491
|
if isinstance(prim, str):
|
|
1490
1492
|
prim = Primitive(prim)
|
|
1491
|
-
indexing = prim.indexing
|
|
1492
1493
|
|
|
1493
|
-
|
|
1494
|
-
|
|
1494
|
+
class Indexing(Enum):
|
|
1495
|
+
ij = 0
|
|
1496
|
+
xy = 1
|
|
1497
|
+
|
|
1498
|
+
def vmap_rule(inputs_bdim, indexing_bdim):
|
|
1499
|
+
is_all_none, result = vmap_general_preprocess(prim, inputs_bdim, indexing_bdim)
|
|
1495
1500
|
if is_all_none:
|
|
1496
1501
|
return result
|
|
1497
1502
|
|
|
1498
1503
|
if not isinstance(inputs_bdim, (tuple)):
|
|
1499
1504
|
_raise_value_error("The inputs of P.Meshgrid is not tuple.")
|
|
1500
|
-
args = inputs_bdim
|
|
1505
|
+
args = inputs_bdim
|
|
1501
1506
|
if len(args) <= 1:
|
|
1502
1507
|
_raise_value_error(
|
|
1503
1508
|
"The input number of P.Meshgrid must be greater than 1.")
|
|
@@ -1518,7 +1523,9 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1518
1523
|
output_shape.insert(0, axis_size)
|
|
1519
1524
|
ones_shape.insert(0, axis_size)
|
|
1520
1525
|
|
|
1521
|
-
|
|
1526
|
+
indexing, _ = indexing_bdim
|
|
1527
|
+
|
|
1528
|
+
if indexing == Indexing.xy.value:
|
|
1522
1529
|
output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
|
|
1523
1530
|
shape = tuple(output_shape)
|
|
1524
1531
|
|
|
@@ -1531,7 +1538,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1531
1538
|
for each_arg in args:
|
|
1532
1539
|
x, bdim = each_arg
|
|
1533
1540
|
x = _bdim_at_front(x, bdim, axis_size)
|
|
1534
|
-
shape_index = (1 - index) if (index <= 1 and indexing ==
|
|
1541
|
+
shape_index = (1 - index) if (index <= 1 and indexing == Indexing.xy.value) else index
|
|
1535
1542
|
ones_shape[shape_index + 1] = output_shape[shape_index + 1]
|
|
1536
1543
|
x = P.Reshape()(x, tuple(ones_shape))
|
|
1537
1544
|
output = P.Mul()(x, ones_tensor)
|
|
@@ -1889,10 +1896,6 @@ def get_slice_vmap_rule(prim, axis_size):
|
|
|
1889
1896
|
@vmap_rules_getters.register(P.Squeeze)
|
|
1890
1897
|
def get_squeeze_vmap_rule(prim, axis_size):
|
|
1891
1898
|
"""VmapRule for `Squeeze`."""
|
|
1892
|
-
if hasattr(prim, 'axis'):
|
|
1893
|
-
prim_axis = prim.axis
|
|
1894
|
-
else:
|
|
1895
|
-
prim_axis = None
|
|
1896
1899
|
|
|
1897
1900
|
@_primexpr
|
|
1898
1901
|
def move_axis(axes):
|
|
@@ -1911,27 +1914,26 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1911
1914
|
new_axis += (i,)
|
|
1912
1915
|
return new_axis
|
|
1913
1916
|
|
|
1914
|
-
def vmap_rule(x_bdim):
|
|
1915
|
-
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
1917
|
+
def vmap_rule(x_bdim, axis_bdim):
|
|
1918
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
|
|
1916
1919
|
if is_all_none:
|
|
1917
1920
|
return result
|
|
1918
1921
|
|
|
1919
1922
|
x, x_dim = x_bdim
|
|
1923
|
+
axis, _ = axis_bdim
|
|
1920
1924
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
1921
1925
|
|
|
1922
|
-
if
|
|
1926
|
+
if axis is None:
|
|
1923
1927
|
if axis_size == 1:
|
|
1924
1928
|
new_axis = generate_all_axis_except_first(F.rank(x))
|
|
1925
|
-
|
|
1926
|
-
out = batch_squeeze(x)
|
|
1929
|
+
out = prim(x, new_axis)
|
|
1927
1930
|
return out, 0
|
|
1928
1931
|
|
|
1929
|
-
out = prim(x)
|
|
1932
|
+
out = prim(x, axis)
|
|
1930
1933
|
return out, 0
|
|
1931
1934
|
|
|
1932
|
-
new_axis = move_axis(
|
|
1933
|
-
|
|
1934
|
-
out = batch_squeeze(x)
|
|
1935
|
+
new_axis = move_axis(axis)
|
|
1936
|
+
out = prim(x, new_axis)
|
|
1935
1937
|
return out, 0
|
|
1936
1938
|
|
|
1937
1939
|
return vmap_rule
|
mindspore/ops/_vmap/vmap_base.py
CHANGED
|
@@ -512,8 +512,6 @@ _ops_vmap_clone_prim_dict = {
|
|
|
512
512
|
"ApplyAdagradV2": P.ApplyAdagradV2,
|
|
513
513
|
"UniformCandidateSampler": UniformCandidateSampler,
|
|
514
514
|
"UniqueWithPad": P.UniqueWithPad,
|
|
515
|
-
"CdistGrad": G.CdistGrad,
|
|
516
|
-
"Cdist": P.Cdist,
|
|
517
515
|
"STFT": math_ops.STFT,
|
|
518
516
|
"Conv2D": P.Conv2D,
|
|
519
517
|
"Conv3D": P.Conv3D,
|
|
@@ -25,7 +25,9 @@ from mindspore.ops.primitive import _primexpr
|
|
|
25
25
|
from mindspore.ops.function import _VmapGeneralRule
|
|
26
26
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
|
|
27
27
|
_bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
|
|
28
|
-
from mindspore.ops
|
|
28
|
+
from mindspore.ops import auto_generate as gen
|
|
29
|
+
from mindspore._c_expression import FormatEnum as Format
|
|
30
|
+
from mindspore._c_expression import ReductionEnum as Reduction
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
@vmap_rules_getters.register(G.NLLLossGrad)
|
|
@@ -225,33 +227,35 @@ def get_max_pool3d_grad_with_argmax_vmap_rule(prim, axis_size):
|
|
|
225
227
|
return vmap_rule
|
|
226
228
|
|
|
227
229
|
|
|
228
|
-
@vmap_rules_getters.register(
|
|
230
|
+
@vmap_rules_getters.register(gen.CdistGrad)
|
|
229
231
|
def get_cdist_grad_vmap_rule(prim, axis_size):
|
|
230
232
|
"""VmapRule for `cdist grad` operation."""
|
|
231
|
-
if
|
|
232
|
-
batch_rank = prim.batch_rank + 1
|
|
233
|
+
if prim.has_label("batch_rank"):
|
|
234
|
+
batch_rank = prim.get_label("batch_rank") + 1
|
|
233
235
|
else:
|
|
234
236
|
batch_rank = 1
|
|
235
237
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
+
prim = prim.clone()
|
|
239
|
+
prim.set_label('batch_rank', batch_rank)
|
|
238
240
|
|
|
239
|
-
def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim):
|
|
240
|
-
is_all_none, result = vmap_general_preprocess(
|
|
241
|
-
|
|
241
|
+
def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim):
|
|
242
|
+
is_all_none, result = vmap_general_preprocess(
|
|
243
|
+
prim, grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim
|
|
244
|
+
)
|
|
242
245
|
if is_all_none:
|
|
243
246
|
return result
|
|
244
247
|
grad, grad_dim = grad_bdim
|
|
245
248
|
x, x_dim = x_bdim
|
|
246
249
|
y, y_dim = y_bdim
|
|
247
250
|
cdist, cdist_dim = cdist_bdim
|
|
251
|
+
p, _ = p_bdim
|
|
248
252
|
|
|
249
253
|
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
250
254
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
251
255
|
y = _bdim_at_front(y, y_dim, axis_size)
|
|
252
256
|
cdist = _bdim_at_front(cdist, cdist_dim, axis_size)
|
|
253
257
|
|
|
254
|
-
out =
|
|
258
|
+
out = prim(grad, x, y, cdist, p)
|
|
255
259
|
return out, 0
|
|
256
260
|
|
|
257
261
|
return vmap_rule
|
|
@@ -673,10 +677,11 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
673
677
|
else:
|
|
674
678
|
_raise_value_error("The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
|
|
675
679
|
|
|
676
|
-
|
|
677
|
-
|
|
680
|
+
def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
|
|
681
|
+
output_mask_bdim):
|
|
678
682
|
is_all_none, result = vmap_general_preprocess(
|
|
679
|
-
prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim
|
|
683
|
+
prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
|
|
684
|
+
output_mask_bdim)
|
|
680
685
|
if is_all_none:
|
|
681
686
|
return result
|
|
682
687
|
|
|
@@ -686,6 +691,7 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
686
691
|
interpolation_mode, _ = interpolation_mode_bdim
|
|
687
692
|
padding_mode, _ = padding_mode_bdim
|
|
688
693
|
align_corners, _ = align_corners_bdim
|
|
694
|
+
output_mask, _ = output_mask_bdim
|
|
689
695
|
|
|
690
696
|
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
691
697
|
grad_shape = F.shape(grad)
|
|
@@ -699,7 +705,8 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
|
|
|
699
705
|
grid_shape = F.shape(grid)
|
|
700
706
|
grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
|
|
701
707
|
|
|
702
|
-
dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
|
|
708
|
+
dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
|
|
709
|
+
padding_mode, align_corners, output_mask)
|
|
703
710
|
dx_shape = F.shape(dx)
|
|
704
711
|
dx_return_shape = input_x_shape[:non_batch_dim_index] + dx_shape[non_batch_dim_index:]
|
|
705
712
|
dx = F.reshape(dx, dx_return_shape)
|
|
@@ -19,6 +19,7 @@ from __future__ import absolute_import
|
|
|
19
19
|
import mindspore.numpy as mnp
|
|
20
20
|
from mindspore.ops import operations as P
|
|
21
21
|
from mindspore.ops import functional as F
|
|
22
|
+
from mindspore.ops import auto_generate as gen
|
|
22
23
|
from mindspore.ops.auto_generate import MatMulExt
|
|
23
24
|
from mindspore.ops.primitive import _primexpr
|
|
24
25
|
from mindspore.common import Tensor
|
|
@@ -29,7 +30,7 @@ from mindspore.ops.primitive import Primitive
|
|
|
29
30
|
from mindspore.ops.function import _VmapGeneralRule
|
|
30
31
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
|
|
31
32
|
get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
|
|
32
|
-
|
|
33
|
+
_bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
|
|
33
34
|
from mindspore.ops.operations.math_ops import Bernoulli, BesselI0, BesselI1, BesselJ0, BesselJ1, \
|
|
34
35
|
BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e, Median
|
|
35
36
|
|
|
@@ -128,28 +129,29 @@ def get_addcxxx_vmap_rule(prim, axis_size):
|
|
|
128
129
|
return vmap_rule
|
|
129
130
|
|
|
130
131
|
|
|
131
|
-
@vmap_rules_getters.register(
|
|
132
|
+
@vmap_rules_getters.register(gen.Cdist)
|
|
132
133
|
def get_cdist_vmap_rule(prim, axis_size):
|
|
133
134
|
"""VmapRule for `cdist` operation."""
|
|
134
|
-
if
|
|
135
|
-
batch_rank = prim.batch_rank + 1
|
|
135
|
+
if prim.has_label("batch_rank"):
|
|
136
|
+
batch_rank = prim.get_label("batch_rank") + 1
|
|
136
137
|
else:
|
|
137
138
|
batch_rank = 1
|
|
138
139
|
|
|
139
|
-
|
|
140
|
-
|
|
140
|
+
prim = prim.clone()
|
|
141
|
+
prim.set_label('batch_rank', batch_rank)
|
|
141
142
|
|
|
142
|
-
def vmap_rule(x_bdim, y_bdim):
|
|
143
|
+
def vmap_rule(x_bdim, y_bdim, p_bdim):
|
|
143
144
|
x, x_dim = x_bdim
|
|
144
145
|
y, y_dim = y_bdim
|
|
146
|
+
p, _ = p_bdim
|
|
145
147
|
|
|
146
|
-
if x_dim is None and y_dim is None:
|
|
148
|
+
if x_dim is None and y_dim is None and p is None:
|
|
147
149
|
out = prim(x, y)
|
|
148
150
|
return (out, None)
|
|
149
151
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
150
152
|
y = _bdim_at_front(y, y_dim, axis_size)
|
|
151
153
|
|
|
152
|
-
out =
|
|
154
|
+
out = prim(x, y, p)
|
|
153
155
|
return out, 0
|
|
154
156
|
|
|
155
157
|
return vmap_rule
|
|
@@ -559,20 +561,17 @@ def get_index_add_vmap_rule(prim, axis_size):
|
|
|
559
561
|
@vmap_rules_getters.register(linalg_ops.Svd)
|
|
560
562
|
def get_svd_vmap_rule(prim, axis_size):
|
|
561
563
|
"""VmapRule for 'Svd' operation."""
|
|
562
|
-
if isinstance(prim, str):
|
|
563
|
-
prim = Primitive(prim)
|
|
564
|
-
compute_uv = True
|
|
565
|
-
else:
|
|
566
|
-
compute_uv = prim.compute_uv
|
|
567
564
|
|
|
568
|
-
def vmap_rule(x_bdim):
|
|
565
|
+
def vmap_rule(x_bdim, full_matrices_bdim, compute_uv_bdim):
|
|
569
566
|
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
570
567
|
if is_all_none:
|
|
571
568
|
return result
|
|
572
569
|
|
|
573
570
|
x, x_dim = x_bdim
|
|
571
|
+
full_matrices, _ = full_matrices_bdim
|
|
572
|
+
compute_uv, _ = compute_uv_bdim
|
|
574
573
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
575
|
-
s, u, v = prim(x)
|
|
574
|
+
s, u, v = prim(x, full_matrices, compute_uv)
|
|
576
575
|
if compute_uv:
|
|
577
576
|
return (s, 0), (u, 0), (v, 0)
|
|
578
577
|
return (s, 0), (u, None), (v, None)
|