mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +564 -395
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -25,7 +25,7 @@ from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, Mu
|
|
|
25
25
|
from mindspore.ops.composite.env_ops import env_get
|
|
26
26
|
from mindspore.ops.function.clip_func import clip_by_global_norm
|
|
27
27
|
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
|
28
|
-
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
|
|
28
|
+
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
|
|
29
29
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
30
30
|
from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
|
|
31
31
|
from mindspore.ops.composite.math_ops import matmul, cummin, mm
|
|
@@ -46,6 +46,7 @@ __all__ = [
|
|
|
46
46
|
'hyper_add',
|
|
47
47
|
'zeros_like',
|
|
48
48
|
'ones_like',
|
|
49
|
+
'_ones_like_for_grad',
|
|
49
50
|
'zip_operation',
|
|
50
51
|
'normal',
|
|
51
52
|
'laplace',
|
mindspore/ops/composite/base.py
CHANGED
|
@@ -24,6 +24,7 @@ import numpy as np
|
|
|
24
24
|
import mindspore as ms
|
|
25
25
|
from mindspore import context
|
|
26
26
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
27
|
+
from mindspore.common.tensor import Tensor
|
|
27
28
|
from mindspore.parallel._utils import _grads_divided_by_device_num_if_recomputation
|
|
28
29
|
from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
|
|
29
30
|
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
|
@@ -35,7 +36,6 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|
|
35
36
|
from mindspore.common import dtype as mstype
|
|
36
37
|
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
|
37
38
|
from mindspore.common.api import _add_flags, _core
|
|
38
|
-
from mindspore.ops.primitive import Primitive
|
|
39
39
|
from mindspore.ops import signature as sig
|
|
40
40
|
|
|
41
41
|
__all__ = [TupleAdd_, ListAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_,
|
|
@@ -358,16 +358,14 @@ class GradOperation(GradOperation_):
|
|
|
358
358
|
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
|
|
359
359
|
# grad in GRAPH_MODE.
|
|
360
360
|
if context.get_context("mode") == context.GRAPH_MODE:
|
|
361
|
-
dynamic_shape_inputs = None
|
|
362
361
|
if isinstance(fn, ms.nn.Cell):
|
|
363
|
-
dynamic_shape_inputs = fn.get_inputs()
|
|
364
362
|
fn.grad_ops_label = True
|
|
365
363
|
if self.get_by_list:
|
|
366
|
-
@jit
|
|
364
|
+
@jit
|
|
367
365
|
def after_grad(*args, **kwargs):
|
|
368
366
|
return grad_(fn, weights)(*args, **kwargs)
|
|
369
367
|
else:
|
|
370
|
-
@jit
|
|
368
|
+
@jit
|
|
371
369
|
def after_grad(*args, **kwargs):
|
|
372
370
|
return grad_(fn)(*args, **kwargs)
|
|
373
371
|
elif self.pynative_:
|
|
@@ -578,11 +576,7 @@ class _Grad(GradOperation_):
|
|
|
578
576
|
outputs = fn(*args, **kwargs)
|
|
579
577
|
if not isinstance(outputs, tuple) or len(outputs) < 2:
|
|
580
578
|
raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
|
|
581
|
-
|
|
582
|
-
stop_gradient = Primitive("StopGradient")
|
|
583
|
-
for item in outputs[1:]:
|
|
584
|
-
res += (stop_gradient(item),)
|
|
585
|
-
return res
|
|
579
|
+
return outputs
|
|
586
580
|
|
|
587
581
|
grad_ = _Grad(self.get_all, self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
|
|
588
582
|
self.get_value, self.return_ids, self.merge_forward)
|
|
@@ -592,20 +586,17 @@ class _Grad(GradOperation_):
|
|
|
592
586
|
# In PYNATIVE_MODE calling Grad from functions decorated with 'jit', use the out layer after_grad do
|
|
593
587
|
# grad in GRAPH_MODE.
|
|
594
588
|
if context.get_context("mode") == context.GRAPH_MODE:
|
|
595
|
-
dynamic_shape_inputs = None
|
|
596
|
-
if isinstance(fn, ms.nn.Cell):
|
|
597
|
-
dynamic_shape_inputs = fn.get_inputs()
|
|
598
589
|
if self.get_by_position:
|
|
599
|
-
@jit
|
|
590
|
+
@jit
|
|
600
591
|
def after_grad(*args):
|
|
601
592
|
return grad_(fn, weights, grad_position)(*args)
|
|
602
593
|
else:
|
|
603
594
|
if self.get_by_list:
|
|
604
|
-
@jit
|
|
595
|
+
@jit
|
|
605
596
|
def after_grad(*args):
|
|
606
597
|
return grad_(fn, weights)(*args)
|
|
607
598
|
else:
|
|
608
|
-
@jit
|
|
599
|
+
@jit
|
|
609
600
|
def after_grad(*args):
|
|
610
601
|
return grad_(fn)(*args)
|
|
611
602
|
elif self.pynative_:
|
|
@@ -615,9 +606,12 @@ class _Grad(GradOperation_):
|
|
|
615
606
|
@_wrap_func
|
|
616
607
|
def after_grad(*args, **kwargs):
|
|
617
608
|
run_args, res = self._pynative_forward_run(fn, grad_, weights, *args, **kwargs)
|
|
618
|
-
|
|
609
|
+
if self.has_aux:
|
|
610
|
+
out = _pynative_executor.grad_aux(fn, grad_, weights, grad_position, *run_args)
|
|
611
|
+
else:
|
|
612
|
+
out = _pynative_executor.grad(fn, grad_, weights, grad_position, *run_args)
|
|
619
613
|
out = _grads_divided_by_device_num_if_recomputation(out)
|
|
620
|
-
if self.return_ids and out:
|
|
614
|
+
if self.return_ids and (isinstance(out, Tensor) or out) and out is not None:
|
|
621
615
|
out = _combine_with_ids(grad_position, weights, out)
|
|
622
616
|
if self.get_value:
|
|
623
617
|
return res, out
|
|
@@ -820,14 +814,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
820
814
|
self.register_fn(type_names, fn)
|
|
821
815
|
self.entries.append((types, fn))
|
|
822
816
|
return fn
|
|
823
|
-
|
|
824
817
|
return deco
|
|
825
818
|
|
|
826
|
-
def
|
|
819
|
+
def _register_default(self, convert_to_interpret=True):
|
|
827
820
|
def deco(fn):
|
|
828
|
-
|
|
821
|
+
if not convert_to_interpret:
|
|
822
|
+
self.register_default_fn(fn)
|
|
823
|
+
else:
|
|
824
|
+
self.default_func = fn
|
|
829
825
|
return fn
|
|
830
|
-
|
|
831
826
|
return deco
|
|
832
827
|
|
|
833
828
|
# pylint: disable=missing-docstring
|
|
@@ -843,7 +838,7 @@ class HyperMap(HyperMap_):
|
|
|
843
838
|
HyperMap will apply the set operation to input sequences.
|
|
844
839
|
|
|
845
840
|
Apply the operations to every element of the sequence or nested sequence. Different
|
|
846
|
-
from
|
|
841
|
+
from :class:`mindspore.ops.Map`, the `HyperMap` supports to apply on nested structure. The
|
|
847
842
|
`HyperMap` also supports dynamic sequences as input, but it does not extend this
|
|
848
843
|
support to nested dynamic sequences.
|
|
849
844
|
|
|
@@ -928,9 +923,9 @@ class Map(Map_):
|
|
|
928
923
|
Apply the operations to every element of the sequence.
|
|
929
924
|
|
|
930
925
|
Args:
|
|
931
|
-
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
|
926
|
+
ops (Union[MultitypeFuncGraph, None], optional): `ops` is the operation to apply. If `ops` is `None`,
|
|
932
927
|
the operations should be put in the first input of the instance. Default: ``None`` .
|
|
933
|
-
reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel performance,
|
|
928
|
+
reverse (bool, optional): The optimizer needs to be inverted in some scenarios to improve parallel performance,
|
|
934
929
|
general users please ignore. `Reverse` is the flag to decide if apply the operation reversely.
|
|
935
930
|
Only supported in graph mode. Default is ``False`` .
|
|
936
931
|
|
|
@@ -85,8 +85,9 @@ def matmul(x1, x2, dtype=None):
|
|
|
85
85
|
def mm(input, mat2):
|
|
86
86
|
r"""
|
|
87
87
|
Returns the matrix product of two arrays.
|
|
88
|
-
|
|
89
|
-
|
|
88
|
+
|
|
89
|
+
If `input` is a :math:`(n \times m)` tensor, `mat2` is a
|
|
90
|
+
:math:`(m \times p)` tensor, `out` will be a :math:`(n \times p)` tensor.
|
|
90
91
|
|
|
91
92
|
Note:
|
|
92
93
|
- This function cannot support broadcasting.
|
|
@@ -95,28 +96,17 @@ def mm(input, mat2):
|
|
|
95
96
|
|
|
96
97
|
Args:
|
|
97
98
|
input (Tensor): The first matrix of matrix multiplication.
|
|
98
|
-
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
99
99
|
mat2 (Tensor): The second matrix of matrix multiplication.
|
|
100
|
-
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
101
100
|
|
|
102
101
|
Returns:
|
|
103
|
-
Tensor or scalar
|
|
104
|
-
|
|
105
|
-
Raises:
|
|
106
|
-
ValueError: If the last dimension of `input` is not the same size as the
|
|
107
|
-
second-to-last dimension of `mat2`.
|
|
108
|
-
ValueError: If `input` or `mat2` is not a Tensor.
|
|
102
|
+
Tensor or scalar
|
|
109
103
|
|
|
110
104
|
Supported Platforms:
|
|
111
105
|
``Ascend`` ``GPU`` ``CPU``
|
|
112
106
|
|
|
113
107
|
Examples:
|
|
114
|
-
>>> import mindspore
|
|
115
|
-
>>>
|
|
116
|
-
>>> import numpy as np
|
|
117
|
-
>>> x1 = ms.Tensor(np.random.rand(2, 3), ms.float32)
|
|
118
|
-
>>> x2 = ms.Tensor(np.random.rand(3, 4), ms.float32)
|
|
119
|
-
>>> out = ops.mm(x1, x2)
|
|
108
|
+
>>> import mindspore
|
|
109
|
+
>>> out = mindspore.ops.mm(mindspore.ops.ones((2, 3)), mindspore.ops.ones((3, 4)))
|
|
120
110
|
>>> print(out.shape)
|
|
121
111
|
(2, 4)
|
|
122
112
|
"""
|
|
@@ -25,7 +25,7 @@ from mindspore.ops.composite.multitype_ops.mod_impl import mod
|
|
|
25
25
|
from mindspore.ops.composite.multitype_ops.getitem_impl import getitem
|
|
26
26
|
from mindspore.ops.composite.multitype_ops.setitem_impl import setitem
|
|
27
27
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
28
|
-
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
|
|
28
|
+
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
|
|
29
29
|
from mindspore.ops.composite.multitype_ops.equal_impl import equal
|
|
30
30
|
from mindspore.ops.composite.multitype_ops.not_equal_impl import not_equal
|
|
31
31
|
from mindspore.ops.composite.multitype_ops.less_impl import less
|
|
@@ -44,6 +44,7 @@ from mindspore.ops.composite.multitype_ops.right_shift_impl import right_shift
|
|
|
44
44
|
from mindspore.ops.composite.multitype_ops.uadd_impl import uadd
|
|
45
45
|
from mindspore.ops.composite.multitype_ops.in_impl import in_
|
|
46
46
|
from mindspore.ops.composite.multitype_ops.not_in_impl import not_in_
|
|
47
|
+
from mindspore.ops.composite.multitype_ops.invert_impl import invert
|
|
47
48
|
__all__ = [
|
|
48
49
|
'add',
|
|
49
50
|
'sub',
|
|
@@ -73,5 +74,7 @@ __all__ = [
|
|
|
73
74
|
'left_shift',
|
|
74
75
|
'right_shift',
|
|
75
76
|
'in_',
|
|
76
|
-
'not_in_'
|
|
77
|
+
'not_in_',
|
|
78
|
+
'invert',
|
|
79
|
+
'_ones_like_for_grad'
|
|
77
80
|
]
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
"""constexpr util"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
from enum import IntEnum
|
|
19
|
-
|
|
19
|
+
import numpy as np
|
|
20
20
|
|
|
21
21
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
22
22
|
from mindspore.ops import functional as F
|
|
@@ -35,6 +35,8 @@ from mindspore import ops
|
|
|
35
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
36
|
from mindspore import _checkparam as validator
|
|
37
37
|
from mindspore.common._stub_tensor import _convert_stub
|
|
38
|
+
from mindspore.ops.auto_generate.gen_ops_prim import select_ext_view_op, slice_ext_op, inplace_copy_op, \
|
|
39
|
+
index_op, inplace_index_put_op
|
|
38
40
|
|
|
39
41
|
slice_get_item = SliceGetItem()
|
|
40
42
|
hyper_map = base.HyperMap()
|
|
@@ -45,9 +47,15 @@ is_parameter = IsParameter()
|
|
|
45
47
|
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
46
48
|
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
47
49
|
|
|
48
|
-
|
|
50
|
+
select_view = SelectView()
|
|
49
51
|
copy_with_slice = CopyWithSlice()
|
|
50
52
|
|
|
53
|
+
tensor_1d = Tensor([0], dtype=mstype.int64)
|
|
54
|
+
empty_tensor_1d = Tensor(shape=(0,), dtype=mstype.int64)
|
|
55
|
+
empty_tensor_9d = Tensor(shape=(0,)*9, dtype=mstype.int64)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
51
59
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
52
60
|
new_axis_mask=0, shrink_axis_mask=0):
|
|
53
61
|
"""strided_slice primitive cache"""
|
|
@@ -148,7 +156,7 @@ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=N
|
|
|
148
156
|
elif transfer_type == ValueTransferType.kSelect:
|
|
149
157
|
data = F.select(Tensor(new_index), value, data)
|
|
150
158
|
elif transfer_type == ValueTransferType.kSelectView:
|
|
151
|
-
data =
|
|
159
|
+
data = select_view(data, arg[0], arg[1])
|
|
152
160
|
elif transfer_type == ValueTransferType.kCopyView:
|
|
153
161
|
value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
|
|
154
162
|
data = copy_with_slice(data, value)
|
|
@@ -196,14 +204,14 @@ def value_update(transfer_types, args, data, value):
|
|
|
196
204
|
return value
|
|
197
205
|
|
|
198
206
|
|
|
199
|
-
def
|
|
207
|
+
def _tensor_getitem_origin(self, index):
|
|
200
208
|
"""Handle tensor getitem"""
|
|
201
209
|
new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
|
|
202
210
|
self, index)
|
|
203
211
|
return data_update(tensor_update_types, tensor_update_args, self, new_index)
|
|
204
212
|
|
|
205
213
|
|
|
206
|
-
def
|
|
214
|
+
def _tensor_setitem_origin(self, index, value):
|
|
207
215
|
"""Handle tensor setitem"""
|
|
208
216
|
setitem_info = setitem_tensor_index_info(self, index, value)
|
|
209
217
|
new_index = setitem_info[0]
|
|
@@ -218,8 +226,213 @@ def _tensor_setitem(self, index, value):
|
|
|
218
226
|
return output
|
|
219
227
|
|
|
220
228
|
|
|
221
|
-
setattr(tensor_operator_registry, "
|
|
222
|
-
setattr(tensor_operator_registry, "
|
|
229
|
+
setattr(tensor_operator_registry, "_tensor_getitem_origin", _tensor_getitem_origin)
|
|
230
|
+
setattr(tensor_operator_registry, "_tensor_setitem_origin", _tensor_setitem_origin)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _record_tensor_index(index, remain_indexes, dim):
|
|
234
|
+
"""Record indexes remained to be used by aclnnIndex/aclnnIndexPut"""
|
|
235
|
+
if len(remain_indexes) > dim:
|
|
236
|
+
remain_indexes[dim] = index
|
|
237
|
+
return remain_indexes
|
|
238
|
+
|
|
239
|
+
while dim > len(remain_indexes):
|
|
240
|
+
# use empty_tensor with dim_num 9 to indicate unused dim
|
|
241
|
+
remain_indexes.append(empty_tensor_9d)
|
|
242
|
+
|
|
243
|
+
remain_indexes.append(index)
|
|
244
|
+
return remain_indexes
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _count_indexed_dims(indexes):
|
|
248
|
+
"""Count indexed dims"""
|
|
249
|
+
count = 0
|
|
250
|
+
for index in indexes:
|
|
251
|
+
if isinstance(index, Tensor):
|
|
252
|
+
if index.dtype == mstype.bool_:
|
|
253
|
+
count += index.ndim
|
|
254
|
+
else:
|
|
255
|
+
count += 1
|
|
256
|
+
elif not isinstance(index, (type(None), type(...), bool)):
|
|
257
|
+
count += 1
|
|
258
|
+
return count
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _do_select(self: Tensor, dim: int, index: int, dim_index: int, self_shape: list):
|
|
262
|
+
"""call select view operator"""
|
|
263
|
+
if not self_shape:
|
|
264
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
265
|
+
dim_size = self_shape[dim]
|
|
266
|
+
if index >= dim_size or index < -dim_size:
|
|
267
|
+
raise IndexError(f"Index {index} is out of bounds for dimension {dim_index} with size {dim_size}")
|
|
268
|
+
index = index + dim_size if index < 0 else index
|
|
269
|
+
return select_ext_view_op(self, dim, index)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _do_slice(self: Tensor, dim: int, index: slice, self_shape: list):
|
|
273
|
+
"""call slice view operator"""
|
|
274
|
+
def _get_index(index, default):
|
|
275
|
+
if index is None:
|
|
276
|
+
return default
|
|
277
|
+
if isinstance(index, Tensor):
|
|
278
|
+
return index.__index__()
|
|
279
|
+
return index
|
|
280
|
+
|
|
281
|
+
if not self_shape:
|
|
282
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
283
|
+
step = _get_index(index.step, 1)
|
|
284
|
+
if step <= 0:
|
|
285
|
+
raise ValueError("slice step must be positive")
|
|
286
|
+
start = _get_index(index.start, 0)
|
|
287
|
+
end = _get_index(index.stop, self_shape[dim])
|
|
288
|
+
if start == 0 and end == self_shape[dim] and step == 1:
|
|
289
|
+
return self
|
|
290
|
+
return slice_ext_op(self, dim, start, end, step)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexed_dims, dim_index, remain_indexes,
|
|
294
|
+
prev_shape):
|
|
295
|
+
"""Process dim in multi dim index"""
|
|
296
|
+
if isinstance(index, bool):
|
|
297
|
+
result = F.expand_dims(prev_result, dim)
|
|
298
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
299
|
+
_record_tensor_index(index_for_bool, remain_indexes, dim)
|
|
300
|
+
prev_shape.insert(dim, 1)
|
|
301
|
+
dim += 1
|
|
302
|
+
return result, dim, remain_indexes, prev_shape
|
|
303
|
+
if isinstance(index, int):
|
|
304
|
+
result = _do_select(prev_result, dim, index, dim_index, prev_shape)
|
|
305
|
+
del prev_shape[dim]
|
|
306
|
+
return result, dim, remain_indexes, prev_shape
|
|
307
|
+
if isinstance(index, slice):
|
|
308
|
+
result = _do_slice(prev_result, dim, index, prev_shape)
|
|
309
|
+
# current dim in prev_shape will not be used later, ignore it
|
|
310
|
+
dim += 1
|
|
311
|
+
return result, dim, remain_indexes, prev_shape
|
|
312
|
+
if isinstance(index, type(...)):
|
|
313
|
+
dim += (orig_tensor.ndim - indexed_dims)
|
|
314
|
+
return prev_result, dim, remain_indexes, prev_shape
|
|
315
|
+
if index is None:
|
|
316
|
+
result = F.expand_dims(prev_result, dim)
|
|
317
|
+
prev_shape.insert(dim, 1)
|
|
318
|
+
dim += 1
|
|
319
|
+
return result, dim, remain_indexes, prev_shape
|
|
320
|
+
if isinstance(index, Tensor):
|
|
321
|
+
result = prev_result
|
|
322
|
+
if index.ndim == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
323
|
+
if index.dtype in mstype.int_type + mstype.uint_type:
|
|
324
|
+
result = _do_select(prev_result, dim, index.item(), dim_index, prev_shape)
|
|
325
|
+
del prev_shape[dim]
|
|
326
|
+
return result, dim, remain_indexes, prev_shape
|
|
327
|
+
# process index with Tensor bool type
|
|
328
|
+
result = F.expand_dims(prev_result, dim)
|
|
329
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
330
|
+
_record_tensor_index(index_for_bool, remain_indexes, dim)
|
|
331
|
+
prev_shape.insert(dim, 1)
|
|
332
|
+
dim += 1
|
|
333
|
+
return result, dim, remain_indexes, prev_shape
|
|
334
|
+
_record_tensor_index(index, remain_indexes, dim)
|
|
335
|
+
dim += 1
|
|
336
|
+
return result, dim, remain_indexes, prev_shape
|
|
337
|
+
raise IndexError(f"Invalid tensor index type {index}")
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
|
|
341
|
+
"""Process indexes in tuple"""
|
|
342
|
+
self_viewed = self
|
|
343
|
+
self_viewed_shape = list(self.shape)
|
|
344
|
+
dim = 0
|
|
345
|
+
for i, index in enumerate(indexes):
|
|
346
|
+
if isinstance(index, (list, tuple, np.ndarray)):
|
|
347
|
+
index_np = np.array(index) if isinstance(index, (list, tuple)) else index
|
|
348
|
+
if index_np.dtype in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
|
349
|
+
np.float16, np.float32, np.float64):
|
|
350
|
+
index = Tensor(index_np, mstype.int64)
|
|
351
|
+
elif index_np.dtype == np.bool_:
|
|
352
|
+
index = Tensor(index_np, mstype.bool_)
|
|
353
|
+
else:
|
|
354
|
+
raise TypeError(f"Index {index} contain unsupported elements")
|
|
355
|
+
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
|
|
356
|
+
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
|
|
357
|
+
return self_viewed, remain_indexes
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _wrap_index_to_tuple(index):
|
|
361
|
+
"""Wrap index to tuple"""
|
|
362
|
+
if isinstance(index, tuple):
|
|
363
|
+
return index
|
|
364
|
+
if isinstance(index, list):
|
|
365
|
+
if len(index) < 32 and any(isinstance(i, (Tensor, list, tuple, slice, type(None), type(...))) for i in index):
|
|
366
|
+
return tuple(index)
|
|
367
|
+
return (index,)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _tensor_getitem(self, index):
|
|
371
|
+
"""Handle tensor getitem"""
|
|
372
|
+
if isinstance(index, bool):
|
|
373
|
+
self_viewed = F.expand_dims(self, 0)
|
|
374
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
375
|
+
return index_op(self_viewed, [index_for_bool])
|
|
376
|
+
if isinstance(index, int):
|
|
377
|
+
return _do_select(self, 0, index, 0, list(self.shape))
|
|
378
|
+
if isinstance(index, slice):
|
|
379
|
+
result = _do_slice(self, 0, index, list(self.shape))
|
|
380
|
+
return result
|
|
381
|
+
if index is None:
|
|
382
|
+
return F.expand_dims(self, 0)
|
|
383
|
+
if isinstance(index, type(...)):
|
|
384
|
+
return self
|
|
385
|
+
indexes = _wrap_index_to_tuple(index)
|
|
386
|
+
indexed_dims = _count_indexed_dims(indexes)
|
|
387
|
+
if self.ndim < indexed_dims:
|
|
388
|
+
raise IndexError(f"too many indices for tensor with dimension size {self.ndim}")
|
|
389
|
+
remain_indexes = []
|
|
390
|
+
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
391
|
+
if not remain_indexes:
|
|
392
|
+
return self_viewed
|
|
393
|
+
return index_op(self_viewed, remain_indexes)
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def _tensor_setitem(self, index, value):
|
|
397
|
+
"""Handle tensor setitem"""
|
|
398
|
+
if not isinstance(value, Tensor):
|
|
399
|
+
if isinstance(value, (bool, int, float)):
|
|
400
|
+
value = Tensor(value, dtype=self.dtype)
|
|
401
|
+
else:
|
|
402
|
+
raise TypeError(f"Can't assign a {type(value)} to a {self.dtype}.")
|
|
403
|
+
|
|
404
|
+
if isinstance(index, bool) and index is False:
|
|
405
|
+
return self
|
|
406
|
+
if isinstance(index, type(...)):
|
|
407
|
+
inplace_copy_op(self, value)
|
|
408
|
+
return self
|
|
409
|
+
if index is None or (isinstance(index, bool) and index is True):
|
|
410
|
+
self_viewed = F.expand_dims(self, 0)
|
|
411
|
+
inplace_copy_op(self_viewed, value)
|
|
412
|
+
return self
|
|
413
|
+
if isinstance(index, int):
|
|
414
|
+
self_viewed = _do_select(self, 0, index, 0, list(self.shape))
|
|
415
|
+
inplace_copy_op(self_viewed, value)
|
|
416
|
+
return self
|
|
417
|
+
if isinstance(index, slice):
|
|
418
|
+
self_viewed = _do_slice(self, 0, index, list(self.shape))
|
|
419
|
+
inplace_copy_op(self_viewed, value)
|
|
420
|
+
return self
|
|
421
|
+
indexes = _wrap_index_to_tuple(index)
|
|
422
|
+
indexed_dims = _count_indexed_dims(indexes)
|
|
423
|
+
if self.ndim < indexed_dims:
|
|
424
|
+
raise IndexError(f"too many indices for tensor with dimension size {self.ndim}")
|
|
425
|
+
remain_indexes = []
|
|
426
|
+
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
427
|
+
if not remain_indexes:
|
|
428
|
+
inplace_copy_op(self_viewed, value)
|
|
429
|
+
return self
|
|
430
|
+
inplace_index_put_op(self_viewed, remain_indexes, value)
|
|
431
|
+
return self
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
setattr(tensor_operator_registry, "_tensor_getitem", _tensor_getitem)
|
|
435
|
+
setattr(tensor_operator_registry, "_tensor_setitem", _tensor_setitem)
|
|
223
436
|
|
|
224
437
|
|
|
225
438
|
def _tensor_add(self, other):
|
|
@@ -313,31 +526,16 @@ def _check_scalar_tensor_args(args):
|
|
|
313
526
|
const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
314
527
|
|
|
315
528
|
|
|
316
|
-
def tensor_item(data
|
|
317
|
-
"""Tensor getitem
|
|
318
|
-
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
529
|
+
def tensor_item(data):
|
|
530
|
+
"""Tensor getitem which has only one element."""
|
|
319
531
|
if data.ndim == 0:
|
|
320
|
-
_check_scalar_tensor_args(args)
|
|
321
532
|
return TensorToScalar()(data)
|
|
322
|
-
if
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
return TensorToScalar()(data[0])
|
|
329
|
-
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
330
|
-
|
|
331
|
-
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
332
|
-
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
333
|
-
|
|
334
|
-
if len(args) == data.ndim:
|
|
335
|
-
return tensor_index_by_tuple(data, args)
|
|
336
|
-
if len(args) > 1:
|
|
337
|
-
const_utils.raise_value_error("Incorrect number of indices for array")
|
|
338
|
-
output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
|
|
339
|
-
return TensorToScalar()(output)
|
|
340
|
-
|
|
533
|
+
if data.shape == (1,):
|
|
534
|
+
return TensorToScalar()(data[0])
|
|
535
|
+
exp_msg = const_utils.gen_exception_msg("The tensor should have only one element. "
|
|
536
|
+
"But the shape of input tensor is {}.", data.shape)
|
|
537
|
+
const_utils.raise_value_error(exp_msg)
|
|
538
|
+
return None
|
|
341
539
|
|
|
342
540
|
def tensor_itemset(data, *args):
|
|
343
541
|
"""Tensor setitem by index and value."""
|
|
@@ -27,7 +27,6 @@ from mindspore.ops.operations import _inner_ops
|
|
|
27
27
|
from mindspore.ops.primitive import constexpr, _primexpr
|
|
28
28
|
from mindspore import log as logger
|
|
29
29
|
from mindspore import context
|
|
30
|
-
from mindspore._c_expression import Tensor as Tensor_
|
|
31
30
|
|
|
32
31
|
ALL_TENSOR = 0
|
|
33
32
|
NO_TENSOR = 1
|
|
@@ -201,7 +200,7 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
|
|
|
201
200
|
|
|
202
201
|
if isinstance(a, (list, tuple)):
|
|
203
202
|
if not a:
|
|
204
|
-
return
|
|
203
|
+
return Tensor(a, dtype)
|
|
205
204
|
# Convert all tuple/nested tuples to lists
|
|
206
205
|
a = _deep_list(a, dim_size)
|
|
207
206
|
# Convert all tensor sub-elements to numpy arrays
|
|
@@ -600,7 +600,8 @@ def _dict_add_dict(x, y):
|
|
|
600
600
|
hyper_add = base.HyperMap(_add_backward)
|
|
601
601
|
|
|
602
602
|
|
|
603
|
-
|
|
603
|
+
# pylint: disable=protected-access
|
|
604
|
+
@add._register_default()
|
|
604
605
|
def default_add(x, y):
|
|
605
606
|
"""Default function for add."""
|
|
606
607
|
return x + y
|
|
@@ -50,7 +50,8 @@ def _scalar_bitwise_and_tensor(x, y):
|
|
|
50
50
|
return F.bitwise_and(x, y)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
|
|
53
|
+
# pylint: disable=protected-access
|
|
54
|
+
@bitwise_and._register_default()
|
|
54
55
|
def default_bitwsie_add(x, y):
|
|
55
56
|
"""Default function for bitwise_and."""
|
|
56
57
|
return x & y
|
|
@@ -50,7 +50,8 @@ def _scalar_bitwise_or_tensor(x, y):
|
|
|
50
50
|
return F.bitwise_or(x, y)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
|
|
53
|
+
# pylint: disable=protected-access
|
|
54
|
+
@bitwise_or._register_default()
|
|
54
55
|
def default_bitwsie_or(x, y):
|
|
55
56
|
"""Default function for bitwise_or."""
|
|
56
57
|
return x | y
|
|
@@ -50,7 +50,8 @@ def _scalar_bitwise_xor_tensor(x, y):
|
|
|
50
50
|
return F.bitwise_xor(x, y)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
|
|
53
|
+
# pylint: disable=protected-access
|
|
54
|
+
@bitwise_xor._register_default()
|
|
54
55
|
def default_bitwsie_xor(x, y):
|
|
55
56
|
"""Default function for bitwise_xor."""
|
|
56
57
|
return x ^ y
|
|
@@ -21,6 +21,7 @@ from mindspore.ops.composite.multitype_ops._constexpr_utils import log_warning,
|
|
|
21
21
|
from mindspore.ops.composite import base
|
|
22
22
|
from mindspore.ops import functional as F
|
|
23
23
|
from mindspore.common import COOTensor
|
|
24
|
+
from mindspore.ops.auto_generate import div_op
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
div = base.MultitypeFuncGraph("div", True)
|
|
@@ -84,7 +85,7 @@ def _div_tensor(x, y):
|
|
|
84
85
|
Returns:
|
|
85
86
|
Tensor, has the same dtype as x.
|
|
86
87
|
"""
|
|
87
|
-
return
|
|
88
|
+
return div_op(x, y)
|
|
88
89
|
|
|
89
90
|
|
|
90
91
|
@div.register("Number", "Tensor")
|
|
@@ -99,7 +100,7 @@ def _scalar_div_tensor(x, y):
|
|
|
99
100
|
Returns:
|
|
100
101
|
Tensor, has the same dtype as x.
|
|
101
102
|
"""
|
|
102
|
-
return
|
|
103
|
+
return div_op(x, y)
|
|
103
104
|
|
|
104
105
|
|
|
105
106
|
@div.register("Tensor", "Number")
|
|
@@ -114,7 +115,7 @@ def _tensor_div_scalar(x, y):
|
|
|
114
115
|
Returns:
|
|
115
116
|
Tensor, has the same dtype as x.
|
|
116
117
|
"""
|
|
117
|
-
return
|
|
118
|
+
return div_op(x, y)
|
|
118
119
|
|
|
119
120
|
|
|
120
121
|
@div.register("Tuple", "Tensor")
|
|
@@ -181,7 +182,8 @@ def _tensor_div_list(x, y):
|
|
|
181
182
|
return F.tensor_div(x, y)
|
|
182
183
|
|
|
183
184
|
|
|
184
|
-
|
|
185
|
+
# pylint: disable=protected-access
|
|
186
|
+
@div._register_default()
|
|
185
187
|
def default_div(x, y):
|
|
186
188
|
"""Default function for div."""
|
|
187
189
|
if y != 0:
|
|
@@ -306,7 +306,7 @@ def _number_equal_string(x, y):
|
|
|
306
306
|
|
|
307
307
|
Args:
|
|
308
308
|
x (Number): The first input which is a number.
|
|
309
|
-
y (
|
|
309
|
+
y (str): The second input which is a string.
|
|
310
310
|
|
|
311
311
|
Returns:
|
|
312
312
|
bool, return false.
|
|
@@ -320,7 +320,7 @@ def _string_equal_number(x, y):
|
|
|
320
320
|
Determine if number equal string.
|
|
321
321
|
|
|
322
322
|
Args:
|
|
323
|
-
x (
|
|
323
|
+
x (str): The first input which is a string.
|
|
324
324
|
y (Number): The second input which is a number.
|
|
325
325
|
|
|
326
326
|
Returns:
|
|
@@ -329,7 +329,8 @@ def _string_equal_number(x, y):
|
|
|
329
329
|
return False
|
|
330
330
|
|
|
331
331
|
|
|
332
|
-
|
|
332
|
+
# pylint: disable=protected-access
|
|
333
|
+
@equal._register_default()
|
|
333
334
|
def default_equal(x, y):
|
|
334
335
|
"""Default function for equal."""
|
|
335
336
|
return x == y
|
|
@@ -80,7 +80,8 @@ def _tensor_floordiv_list(x, y):
|
|
|
80
80
|
return F.tensor_floordiv(x, y)
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
|
|
83
|
+
# pylint: disable=protected-access
|
|
84
|
+
@floordiv._register_default()
|
|
84
85
|
def default_floordiv(x, y):
|
|
85
86
|
"""Default function for floordiv."""
|
|
86
87
|
if y == 0:
|