mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/numpy/array_ops.py
CHANGED
|
@@ -19,8 +19,6 @@ import operator
|
|
|
19
19
|
|
|
20
20
|
from mindspore.common import dtype as mstype
|
|
21
21
|
from mindspore.common import Tensor, mutable
|
|
22
|
-
from mindspore.ops import operations as P
|
|
23
|
-
from mindspore.ops import functional as F
|
|
24
22
|
from mindspore.ops.primitive import constexpr, _primexpr
|
|
25
23
|
from mindspore.nn import Cell
|
|
26
24
|
from mindspore import ops
|
|
@@ -73,17 +71,24 @@ def expand_dims(a, axis):
|
|
|
73
71
|
if not isinstance(axis, (int, tuple, list)):
|
|
74
72
|
_raise_type_error("axis must be tuple, list or int, but got ", axis)
|
|
75
73
|
if isinstance(axis, int):
|
|
76
|
-
return
|
|
74
|
+
return ops.expand_dims(a, axis)
|
|
77
75
|
ndim = a.ndim + len(axis)
|
|
78
76
|
axis = _canonicalize_axis(axis, ndim)
|
|
79
77
|
for ax in axis:
|
|
80
|
-
a =
|
|
78
|
+
a = ops.expand_dims(a, ax)
|
|
81
79
|
return a
|
|
82
80
|
|
|
83
81
|
|
|
84
82
|
def squeeze(a, axis=None):
|
|
85
83
|
"""
|
|
86
|
-
|
|
84
|
+
Return the Tensor after deleting the dimension of size 1 in the specified `axis`.
|
|
85
|
+
|
|
86
|
+
If :math:`axis=None`, it will remove all the dimensions of size 1.
|
|
87
|
+
If `axis` is specified, it will remove the dimensions of size 1 in the given `axis`.
|
|
88
|
+
For example, if the dimension is not specified :math:`axis=None`, input shape is (A, 1, B, C, 1, D),
|
|
89
|
+
then the shape of the output Tensor is (A, B, C, D). If the dimension is specified, the squeeze operation
|
|
90
|
+
is only performed in the specified dimension. If input shape is (A, 1, B), when :math:`axis=0` or :math:`axis=2`,
|
|
91
|
+
the input tensor is not changed, while when :math:`axis=1`, the input tensor shape is changed to (A, B).
|
|
87
92
|
|
|
88
93
|
Args:
|
|
89
94
|
a (Tensor): Input tensor array.
|
|
@@ -199,14 +204,14 @@ def rollaxis(x, axis, start=0):
|
|
|
199
204
|
if not isinstance(start, int):
|
|
200
205
|
_raise_type_error("integer argument expected, but got ", start)
|
|
201
206
|
|
|
202
|
-
shape =
|
|
203
|
-
ndim =
|
|
207
|
+
shape = ops.shape(x)
|
|
208
|
+
ndim = ops.tuple_len(shape)
|
|
204
209
|
|
|
205
210
|
axis = _check_axes_range(axis, ndim)
|
|
206
211
|
start = _check_start_normalize(start, ndim)
|
|
207
212
|
if start - axis >= 0 and start - axis <= 1:
|
|
208
213
|
return x
|
|
209
|
-
perm =
|
|
214
|
+
perm = ops.make_range(0, ndim)
|
|
210
215
|
new_perm = None
|
|
211
216
|
if start < axis:
|
|
212
217
|
if axis + 1 < ndim:
|
|
@@ -222,7 +227,7 @@ def rollaxis(x, axis, start=0):
|
|
|
222
227
|
new_perm = perm[0:axis] + perm[axis + 1:start] + \
|
|
223
228
|
perm[axis:axis + 1]
|
|
224
229
|
|
|
225
|
-
return
|
|
230
|
+
return ops.transpose(x, new_perm)
|
|
226
231
|
|
|
227
232
|
|
|
228
233
|
def swapaxes(x, axis1, axis2):
|
|
@@ -409,7 +414,7 @@ def concatenate(arrays, axis=0):
|
|
|
409
414
|
# as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5))
|
|
410
415
|
if axis is None or axis >= MAX_NUMPY_DIMS:
|
|
411
416
|
return ravel(arrays)
|
|
412
|
-
arr_shape =
|
|
417
|
+
arr_shape = ops.shape(arrays)
|
|
413
418
|
_check_axes_range((axis,), len(arr_shape))
|
|
414
419
|
# move axis 0 to the disiganated position, while keep other axes' relative
|
|
415
420
|
# positions unchanged
|
|
@@ -424,12 +429,12 @@ def concatenate(arrays, axis=0):
|
|
|
424
429
|
flattened_arrays += (ravel(arr),)
|
|
425
430
|
axis = -1
|
|
426
431
|
flattened_arrays = _promote_type_for_concatenate(flattened_arrays)
|
|
427
|
-
return
|
|
432
|
+
return ops.Concat(axis)(flattened_arrays)
|
|
428
433
|
|
|
429
434
|
# convert a list of tensor to a tuple of tensor
|
|
430
435
|
arrays = _convert_list_tensor_to_tuple_tensor(arrays)
|
|
431
436
|
|
|
432
|
-
arr_shape =
|
|
437
|
+
arr_shape = ops.shape(arrays[0])
|
|
433
438
|
_check_axes_range((axis,), len(arr_shape))
|
|
434
439
|
|
|
435
440
|
# if only one tensor in the tuple/list, return the tensor itself
|
|
@@ -437,7 +442,7 @@ def concatenate(arrays, axis=0):
|
|
|
437
442
|
return arrays[0]
|
|
438
443
|
|
|
439
444
|
arrays = _promote_type_for_concatenate(arrays)
|
|
440
|
-
return
|
|
445
|
+
return ops.Concat(axis)(arrays)
|
|
441
446
|
|
|
442
447
|
|
|
443
448
|
def append(arr, values, axis=None):
|
|
@@ -476,7 +481,7 @@ def append(arr, values, axis=None):
|
|
|
476
481
|
values = values.ravel()
|
|
477
482
|
else:
|
|
478
483
|
_check_axis_in_range(axis, arr.ndim)
|
|
479
|
-
if
|
|
484
|
+
if ops.rank(arr) != ops.rank(values):
|
|
480
485
|
_raise_value_error("all tensors must have same number of dimensions")
|
|
481
486
|
return concatenate((arr, values), axis)
|
|
482
487
|
|
|
@@ -518,13 +523,13 @@ def column_stack(tup):
|
|
|
518
523
|
trans_tup = ()
|
|
519
524
|
for tensor in tup:
|
|
520
525
|
if tensor.ndim < 1:
|
|
521
|
-
tensor =
|
|
526
|
+
tensor = ops.expand_dims(tensor, 0)
|
|
522
527
|
if tensor.ndim == 1:
|
|
523
|
-
tensor =
|
|
528
|
+
tensor = ops.expand_dims(tensor, 1)
|
|
524
529
|
trans_tup += (tensor,)
|
|
525
530
|
if not trans_tup:
|
|
526
531
|
_raise_value_error("Need at least one tensor to concatenate.")
|
|
527
|
-
return
|
|
532
|
+
return ops.Concat(1)(trans_tup)
|
|
528
533
|
|
|
529
534
|
|
|
530
535
|
def vstack(tup):
|
|
@@ -568,7 +573,7 @@ def vstack(tup):
|
|
|
568
573
|
trans_tup += (tensor,)
|
|
569
574
|
if not trans_tup:
|
|
570
575
|
_raise_value_error("Need at least one tensor to concatenate.")
|
|
571
|
-
return
|
|
576
|
+
return ops.Concat(0)(trans_tup)
|
|
572
577
|
|
|
573
578
|
|
|
574
579
|
def hstack(tup):
|
|
@@ -608,13 +613,13 @@ def hstack(tup):
|
|
|
608
613
|
tuple_of_tensor = ()
|
|
609
614
|
for tensor in tup:
|
|
610
615
|
if tensor.ndim < 1:
|
|
611
|
-
tensor =
|
|
616
|
+
tensor = ops.expand_dims(tensor, 0)
|
|
612
617
|
tuple_of_tensor += (tensor,)
|
|
613
618
|
if not tuple_of_tensor:
|
|
614
619
|
_raise_value_error("Need at least one tensor to concatenate.")
|
|
615
620
|
if tuple_of_tensor[0].ndim <= 1:
|
|
616
|
-
return
|
|
617
|
-
return
|
|
621
|
+
return ops.Concat(0)(tuple_of_tensor)
|
|
622
|
+
return ops.Concat(1)(tuple_of_tensor)
|
|
618
623
|
|
|
619
624
|
|
|
620
625
|
def dstack(tup):
|
|
@@ -658,11 +663,11 @@ def dstack(tup):
|
|
|
658
663
|
if tensor.ndim <= 1:
|
|
659
664
|
tensor = _expand(tensor, 2, 0)
|
|
660
665
|
if tensor.ndim == 2:
|
|
661
|
-
tensor =
|
|
666
|
+
tensor = ops.expand_dims(tensor, 2)
|
|
662
667
|
trans_tup += (tensor,)
|
|
663
668
|
if not trans_tup:
|
|
664
669
|
_raise_value_error("Need at least one tensor to concatenate.")
|
|
665
|
-
return
|
|
670
|
+
return ops.Concat(2)(trans_tup)
|
|
666
671
|
|
|
667
672
|
|
|
668
673
|
def where(condition, x=None, y=None):
|
|
@@ -705,42 +710,42 @@ def where(condition, x=None, y=None):
|
|
|
705
710
|
"""
|
|
706
711
|
condition, x, y = _to_tensor(condition, x, y)
|
|
707
712
|
# type promotes input tensors
|
|
708
|
-
dtype1 =
|
|
709
|
-
dtype2 =
|
|
713
|
+
dtype1 = ops.dtype(x)
|
|
714
|
+
dtype2 = ops.dtype(y)
|
|
710
715
|
dtype = _promote(dtype1, dtype2)
|
|
711
716
|
if not _check_same_type(dtype1, dtype):
|
|
712
|
-
x =
|
|
717
|
+
x = ops.cast(x, dtype)
|
|
713
718
|
if not _check_same_type(dtype2, dtype):
|
|
714
|
-
y =
|
|
719
|
+
y = ops.cast(y, dtype)
|
|
715
720
|
is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(dtype2, mstype.bool_)
|
|
716
721
|
if is_bool:
|
|
717
722
|
# select does not support bool type for x or y
|
|
718
|
-
x =
|
|
719
|
-
y =
|
|
723
|
+
x = ops.cast(x, mstype.float32)
|
|
724
|
+
y = ops.cast(y, mstype.float32)
|
|
720
725
|
|
|
721
|
-
dynamic =
|
|
722
|
-
or
|
|
726
|
+
dynamic = ops.is_sequence_value_unknown(ops.shape(condition)) or ops.is_sequence_value_unknown(ops.shape(x))\
|
|
727
|
+
or ops.is_sequence_value_unknown(ops.shape(y))
|
|
723
728
|
# As select op currently does not support broadcast, broadcasts input tensors
|
|
724
729
|
if not dynamic:
|
|
725
|
-
shape_out = _infer_out_shape(
|
|
726
|
-
|
|
730
|
+
shape_out = _infer_out_shape(ops.shape(condition),
|
|
731
|
+
ops.shape(x), ops.shape(y))
|
|
727
732
|
condition = _broadcast_to_shape(condition, shape_out)
|
|
728
733
|
x = _broadcast_to_shape(x, shape_out)
|
|
729
734
|
y = _broadcast_to_shape(y, shape_out)
|
|
730
735
|
else:
|
|
731
736
|
# Get the broadcast shape through broadcast calculation
|
|
732
737
|
add_x_y = x + y
|
|
733
|
-
add_out = condition +
|
|
734
|
-
shape_out =
|
|
738
|
+
add_out = condition + ops.cast(add_x_y, condition.dtype)
|
|
739
|
+
shape_out = ops.Shape()(add_out)
|
|
735
740
|
condition = ops.broadcast_to(condition, shape_out)
|
|
736
741
|
x = ops.broadcast_to(x, shape_out)
|
|
737
742
|
y = ops.broadcast_to(y, shape_out)
|
|
738
743
|
|
|
739
|
-
if not _check_same_type(
|
|
740
|
-
condition =
|
|
741
|
-
res =
|
|
744
|
+
if not _check_same_type(ops.dtype(condition), mstype.bool_):
|
|
745
|
+
condition = ops.cast(condition, mstype.bool_)
|
|
746
|
+
res = ops.select(condition, x, y)
|
|
742
747
|
if is_bool:
|
|
743
|
-
res =
|
|
748
|
+
res = ops.cast(res, mstype.bool_)
|
|
744
749
|
return res
|
|
745
750
|
|
|
746
751
|
|
|
@@ -873,13 +878,13 @@ def atleast_3d(*arys):
|
|
|
873
878
|
"""
|
|
874
879
|
res = []
|
|
875
880
|
for arr in arys:
|
|
876
|
-
ndim =
|
|
881
|
+
ndim = ops.rank(arr)
|
|
877
882
|
if ndim == 0:
|
|
878
|
-
arr =
|
|
883
|
+
arr = ops.reshape(arr, (1, 1, 1))
|
|
879
884
|
elif ndim == 1:
|
|
880
|
-
arr =
|
|
885
|
+
arr = ops.reshape(arr, (1, ops.size(arr), 1))
|
|
881
886
|
elif ndim == 2:
|
|
882
|
-
arr =
|
|
887
|
+
arr = ops.reshape(arr, ops.shape(arr) + (1,))
|
|
883
888
|
res.append(arr)
|
|
884
889
|
if len(res) == 1:
|
|
885
890
|
return res[0]
|
|
@@ -927,24 +932,24 @@ def stack(arrays, axis=0):
|
|
|
927
932
|
"""
|
|
928
933
|
|
|
929
934
|
if isinstance(arrays, Tensor):
|
|
930
|
-
shape =
|
|
931
|
-
ndim =
|
|
935
|
+
shape = ops.shape(arrays)
|
|
936
|
+
ndim = ops.rank(arrays)
|
|
932
937
|
axis = axis % ndim
|
|
933
|
-
axes =
|
|
938
|
+
axes = ops.make_range(ndim)
|
|
934
939
|
perm = axes[1:axis + 1] + (0,) + axes[axis + 1:]
|
|
935
940
|
if _is_shape_empty(shape):
|
|
936
941
|
return _empty(mstype.float32, shape[1:axis + 1] + (shape[0],) + shape[axis + 1:])
|
|
937
942
|
return transpose(arrays, perm)
|
|
938
943
|
|
|
939
944
|
if isinstance(arrays, (list, tuple)):
|
|
940
|
-
shape = (len(arrays),) +
|
|
945
|
+
shape = (len(arrays),) + ops.shape(arrays[0])
|
|
941
946
|
ndim = len(shape)
|
|
942
947
|
axis = axis % ndim
|
|
943
948
|
if _is_shape_empty(shape):
|
|
944
949
|
return _empty(mstype.float32, shape[1:axis + 1] + (shape[0],) + shape[axis + 1:])
|
|
945
950
|
seq = ()
|
|
946
951
|
for arr in arrays:
|
|
947
|
-
seq += (
|
|
952
|
+
seq += (ops.expand_dims(arr, axis),)
|
|
948
953
|
return concatenate(seq, axis)
|
|
949
954
|
return _raise_value_error('input arrays must be Tensor, tuple, or list')
|
|
950
955
|
|
|
@@ -954,7 +959,7 @@ class UniqueNet(Cell):
|
|
|
954
959
|
|
|
955
960
|
def __init__(self):
|
|
956
961
|
super(UniqueNet, self).__init__()
|
|
957
|
-
self.unique =
|
|
962
|
+
self.unique = ops.Unique()
|
|
958
963
|
|
|
959
964
|
def construct(self, x):
|
|
960
965
|
return self.unique(x)
|
|
@@ -998,7 +1003,7 @@ def unique(x, return_inverse=False):
|
|
|
998
1003
|
value= [0, 1, 1, 1, 2, 3, 4]))
|
|
999
1004
|
"""
|
|
1000
1005
|
_check_input_tensor(x)
|
|
1001
|
-
if
|
|
1006
|
+
if ops.tuple_len(ops.shape(x)) > 1:
|
|
1002
1007
|
x = ravel(x)
|
|
1003
1008
|
uniq = UniqueNet()
|
|
1004
1009
|
res = uniq(x)
|
|
@@ -1032,7 +1037,7 @@ def roll_along_axis(a, shift, axis):
|
|
|
1032
1037
|
end1 = ()
|
|
1033
1038
|
end2 = ()
|
|
1034
1039
|
stride = _list_comprehensions(a.ndim, 1, True)
|
|
1035
|
-
for i in
|
|
1040
|
+
for i in ops.make_range(a.ndim):
|
|
1036
1041
|
if i != axis:
|
|
1037
1042
|
begin1 += (0,)
|
|
1038
1043
|
end1 += (a.shape[i],)
|
|
@@ -1043,8 +1048,8 @@ def roll_along_axis(a, shift, axis):
|
|
|
1043
1048
|
end1 += (a.shape[i],)
|
|
1044
1049
|
begin2 += (0,)
|
|
1045
1050
|
end2 += (shift,)
|
|
1046
|
-
return append(
|
|
1047
|
-
|
|
1051
|
+
return append(ops.strided_slice(a, begin1, end1, stride),
|
|
1052
|
+
ops.strided_slice(a, begin2, end2, stride), axis=axis)
|
|
1048
1053
|
|
|
1049
1054
|
|
|
1050
1055
|
def roll(a, shift, axis=None):
|
|
@@ -1086,7 +1091,7 @@ def roll(a, shift, axis=None):
|
|
|
1086
1091
|
original_shape = a.shape
|
|
1087
1092
|
original_dtype = a.dtype
|
|
1088
1093
|
restore_shape = False
|
|
1089
|
-
#
|
|
1094
|
+
# ops.strided_slice only supports float on cpu, this will change once more supports
|
|
1090
1095
|
# are added.
|
|
1091
1096
|
if not _check_is_float(original_dtype):
|
|
1092
1097
|
if not original_dtype in (mstype.complex64, mstype.complex128):
|
|
@@ -1181,14 +1186,14 @@ def moveaxis(a, source, destination):
|
|
|
1181
1186
|
>>> print(output.shape)
|
|
1182
1187
|
(5, 4, 3)
|
|
1183
1188
|
"""
|
|
1184
|
-
ndim =
|
|
1189
|
+
ndim = ops.rank(a)
|
|
1185
1190
|
source = _check_axis_valid(source, ndim)
|
|
1186
1191
|
destination = _check_axis_valid(destination, ndim)
|
|
1187
1192
|
if len(source) != len(destination):
|
|
1188
1193
|
_raise_value_error('`source` and `destination` arguments must have the same number of elements')
|
|
1189
1194
|
perm = _get_moved_perm(ndim, source, destination)
|
|
1190
1195
|
|
|
1191
|
-
return
|
|
1196
|
+
return ops.transpose(a, perm)
|
|
1192
1197
|
|
|
1193
1198
|
|
|
1194
1199
|
def tile(a, reps):
|
|
@@ -1233,13 +1238,13 @@ def tile(a, reps):
|
|
|
1233
1238
|
[[0 1 2 0 1 2]]]
|
|
1234
1239
|
"""
|
|
1235
1240
|
_check_input_tensor(a)
|
|
1236
|
-
ndim =
|
|
1237
|
-
shape =
|
|
1241
|
+
ndim = ops.rank(a)
|
|
1242
|
+
shape = ops.shape(a)
|
|
1238
1243
|
reps = _add_unit_axes(reps, ndim)
|
|
1239
1244
|
if _is_shape_empty(shape) or _is_shape_empty(reps):
|
|
1240
1245
|
shape = _add_unit_axes(shape, len(reps))
|
|
1241
|
-
return _empty(
|
|
1242
|
-
return
|
|
1246
|
+
return _empty(ops.dtype(a), _seq_prod(shape, reps))
|
|
1247
|
+
return ops.tile(a, reps)
|
|
1243
1248
|
|
|
1244
1249
|
|
|
1245
1250
|
@_primexpr
|
|
@@ -1284,7 +1289,7 @@ def broadcast_to(array, shape):
|
|
|
1284
1289
|
def _check(shape_a, shape):
|
|
1285
1290
|
if not _check_can_broadcast_to(shape_a, shape):
|
|
1286
1291
|
_raise_value_error('cannot broadcast with ', shape)
|
|
1287
|
-
shape_a =
|
|
1292
|
+
shape_a = ops.shape(array)
|
|
1288
1293
|
_check(shape_a, shape)
|
|
1289
1294
|
return _broadcast_to_shape(array, shape)
|
|
1290
1295
|
|
|
@@ -1322,7 +1327,7 @@ def broadcast_arrays(*args):
|
|
|
1322
1327
|
[[4, 4, 4],
|
|
1323
1328
|
[5, 5, 5]])]
|
|
1324
1329
|
"""
|
|
1325
|
-
shapes = map(
|
|
1330
|
+
shapes = map(ops.shape, args)
|
|
1326
1331
|
out_shape = _infer_out_shape(*shapes)
|
|
1327
1332
|
res = []
|
|
1328
1333
|
for arr in args:
|
|
@@ -1439,18 +1444,18 @@ def _split(x, indices_or_sections, opname, axis=0):
|
|
|
1439
1444
|
if indices_or_sections > length_along_dim:
|
|
1440
1445
|
_raise_value_error("empty tensor encountered.")
|
|
1441
1446
|
if opname == "split" or length_along_dim % indices_or_sections == 0:
|
|
1442
|
-
res =
|
|
1447
|
+
res = ops.Split(axis_new, indices_or_sections)(x)
|
|
1443
1448
|
else:
|
|
1444
1449
|
num_long_tensor = length_along_dim % indices_or_sections
|
|
1445
1450
|
num_short_tensor = indices_or_sections - num_long_tensor
|
|
1446
1451
|
length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1)
|
|
1447
1452
|
length2 = length_along_dim - length1
|
|
1448
|
-
start1 = _list_comprehensions(
|
|
1453
|
+
start1 = _list_comprehensions(ops.rank(x), 0, True)
|
|
1449
1454
|
size1 = _tuple_setitem(arr_shape, axis_new, length1)
|
|
1450
1455
|
start2 = _tuple_setitem(start1, axis_new, length1)
|
|
1451
1456
|
size2 = _tuple_setitem(arr_shape, axis_new, length2)
|
|
1452
|
-
res =
|
|
1453
|
-
|
|
1457
|
+
res = ops.Split(axis_new, num_long_tensor)(ops.tensor_slice(x, start1, size1)) + \
|
|
1458
|
+
ops.Split(axis_new, num_short_tensor)(ops.tensor_slice(x, start2, size2))
|
|
1454
1459
|
|
|
1455
1460
|
elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections):
|
|
1456
1461
|
res = _split_sub_tensors(x, indices_or_sections, axis_new)
|
|
@@ -1487,7 +1492,7 @@ def _split_sub_tensors(x, indices, axis):
|
|
|
1487
1492
|
end[axis] = idx
|
|
1488
1493
|
if end[axis] <= begin[axis]:
|
|
1489
1494
|
_raise_value_error("empty sub-tensor encountered.")
|
|
1490
|
-
sliced_tensor =
|
|
1495
|
+
sliced_tensor = ops.strided_slice(x, _type_convert(tuple, begin), _type_convert(tuple, end), strides)
|
|
1491
1496
|
sub_tensors.append(sliced_tensor)
|
|
1492
1497
|
return sub_tensors
|
|
1493
1498
|
|
|
@@ -1679,10 +1684,10 @@ def flip(m, axis=None):
|
|
|
1679
1684
|
[3. 2.]]]
|
|
1680
1685
|
"""
|
|
1681
1686
|
_check_input_tensor(m)
|
|
1682
|
-
ndim =
|
|
1687
|
+
ndim = ops.rank(m)
|
|
1683
1688
|
axes = _check_axis_valid(axis, ndim)
|
|
1684
|
-
shape =
|
|
1685
|
-
dtype =
|
|
1689
|
+
shape = ops.shape(m)
|
|
1690
|
+
dtype = ops.dtype(m)
|
|
1686
1691
|
if _is_shape_empty(shape):
|
|
1687
1692
|
return m
|
|
1688
1693
|
if not _check_is_float(dtype):
|
|
@@ -1690,9 +1695,9 @@ def flip(m, axis=None):
|
|
|
1690
1695
|
start = _get_flip_start(ndim, shape, axes)
|
|
1691
1696
|
end = _get_flip_end(ndim, shape, axes)
|
|
1692
1697
|
strides = _get_flip_strides(ndim, axes)
|
|
1693
|
-
res =
|
|
1694
|
-
if not _check_same_type(
|
|
1695
|
-
res =
|
|
1698
|
+
res = ops.strided_slice(m, start, end, strides)
|
|
1699
|
+
if not _check_same_type(ops.dtype(res), dtype):
|
|
1700
|
+
res = ops.cast(res, dtype)
|
|
1696
1701
|
return res
|
|
1697
1702
|
|
|
1698
1703
|
|
|
@@ -1796,49 +1801,49 @@ def take_along_axis(arr, indices, axis):
|
|
|
1796
1801
|
if axis is None:
|
|
1797
1802
|
arr = ravel(arr)
|
|
1798
1803
|
axis = 0
|
|
1799
|
-
ndim =
|
|
1800
|
-
if ndim !=
|
|
1804
|
+
ndim = ops.rank(arr)
|
|
1805
|
+
if ndim != ops.rank(indices):
|
|
1801
1806
|
_raise_value_error('`indices` and `arr` must have the same number of dimensions')
|
|
1802
1807
|
axis = _check_axis_in_range(axis, ndim)
|
|
1803
1808
|
|
|
1804
|
-
shape_arr =
|
|
1805
|
-
shape_indices =
|
|
1809
|
+
shape_arr = ops.shape(arr)
|
|
1810
|
+
shape_indices = ops.shape(indices)
|
|
1806
1811
|
# broadcasts indices against the shape of arr except at axis
|
|
1807
1812
|
indices = _broadcast_to(indices, _tuple_slice(shape_indices, None, axis),
|
|
1808
1813
|
_tuple_slice(shape_arr, None, axis), ndim)
|
|
1809
1814
|
indices = _broadcast_to(indices, _tuple_slice(shape_arr, None, axis + 1) +
|
|
1810
1815
|
_tuple_slice(shape_indices, axis + 1, None), shape_arr, ndim)
|
|
1811
1816
|
arr = _broadcast_to(arr, shape_arr, indices.shape, ndim)
|
|
1812
|
-
return
|
|
1817
|
+
return ops.gather_d(arr, axis, indices)
|
|
1813
1818
|
|
|
1814
1819
|
|
|
1815
1820
|
def _mod(x, y):
|
|
1816
1821
|
"""Computes x mod y."""
|
|
1817
|
-
quotient =
|
|
1818
|
-
prod =
|
|
1819
|
-
return
|
|
1822
|
+
quotient = ops.tensor_floordiv(x, y)
|
|
1823
|
+
prod = ops.tensor_mul(y, quotient)
|
|
1824
|
+
return ops.tensor_sub(x, prod)
|
|
1820
1825
|
|
|
1821
1826
|
|
|
1822
1827
|
def _check_indices(dims, indices, mode, allow_negative_index=True):
|
|
1823
1828
|
"""Checks whether indices are out of bounds."""
|
|
1824
|
-
shape =
|
|
1825
|
-
dtype =
|
|
1829
|
+
shape = ops.shape(indices)
|
|
1830
|
+
dtype = ops.dtype(indices)
|
|
1826
1831
|
if not allow_negative_index:
|
|
1827
|
-
lowerbounds =
|
|
1832
|
+
lowerbounds = ops.fill(dtype, shape, 0)
|
|
1828
1833
|
else:
|
|
1829
|
-
lowerbounds =
|
|
1830
|
-
upperbounds =
|
|
1831
|
-
out_of_lowerbounds =
|
|
1832
|
-
out_of_upperbounds =
|
|
1834
|
+
lowerbounds = ops.fill(dtype, shape, -dims)
|
|
1835
|
+
upperbounds = ops.fill(dtype, shape, dims - 1)
|
|
1836
|
+
out_of_lowerbounds = ops.tensor_lt(indices, lowerbounds)
|
|
1837
|
+
out_of_upperbounds = ops.tensor_gt(indices, upperbounds)
|
|
1833
1838
|
if mode == 'raise':
|
|
1834
1839
|
_raise_unimplemented_error('"raise" mode is not implemented')
|
|
1835
1840
|
if mode == 'wrap':
|
|
1836
|
-
return _mod(indices,
|
|
1841
|
+
return _mod(indices, ops.fill(mstype.float32, shape, dims)).astype(dtype)
|
|
1837
1842
|
if mode != 'clip':
|
|
1838
1843
|
_raise_value_error('invalid mode. Expected "raise", "wrap", or "clip"')
|
|
1839
|
-
zeros =
|
|
1840
|
-
clipped =
|
|
1841
|
-
clipped =
|
|
1844
|
+
zeros = ops.fill(dtype, shape, 0)
|
|
1845
|
+
clipped = ops.select(out_of_lowerbounds, zeros, indices)
|
|
1846
|
+
clipped = ops.select(out_of_upperbounds, upperbounds, clipped)
|
|
1842
1847
|
return clipped
|
|
1843
1848
|
|
|
1844
1849
|
|
|
@@ -1940,7 +1945,7 @@ def repeat(a, repeats, axis=None):
|
|
|
1940
1945
|
[3 4]]
|
|
1941
1946
|
"""
|
|
1942
1947
|
a = _to_tensor(a)
|
|
1943
|
-
return a.
|
|
1948
|
+
return a.repeat_interleave(repeats, axis)
|
|
1944
1949
|
|
|
1945
1950
|
|
|
1946
1951
|
def rot90(a, k=1, axes=(0, 1)):
|
|
@@ -2052,9 +2057,9 @@ def select(condlist, choicelist, default=0):
|
|
|
2052
2057
|
[ 0 1 2 0 16]
|
|
2053
2058
|
"""
|
|
2054
2059
|
condlist, choicelist = _to_tensor(condlist, choicelist)
|
|
2055
|
-
shape_cond =
|
|
2056
|
-
shape_choice =
|
|
2057
|
-
if
|
|
2060
|
+
shape_cond = ops.shape(condlist)
|
|
2061
|
+
shape_choice = ops.shape(choicelist)
|
|
2062
|
+
if ops.rank(condlist) == 0 or ops.rank(choicelist) == 0:
|
|
2058
2063
|
_raise_value_error('input cannot be scalars')
|
|
2059
2064
|
case_num = shape_cond[0]
|
|
2060
2065
|
if shape_choice[0] != case_num:
|
|
@@ -2066,25 +2071,25 @@ def select(condlist, choicelist, default=0):
|
|
|
2066
2071
|
case_size = _infer_out_shape(case_size_cond, case_size_choice)
|
|
2067
2072
|
shape_broadcasted = (case_num,) + case_size
|
|
2068
2073
|
ndim = len(shape_broadcasted)
|
|
2069
|
-
shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim -
|
|
2074
|
+
shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - ops.rank(condlist), 1, True) +
|
|
2070
2075
|
case_size_cond)
|
|
2071
|
-
condlist = _broadcast_to_shape(
|
|
2072
|
-
shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim -
|
|
2076
|
+
condlist = _broadcast_to_shape(ops.reshape(condlist, shape_cond_expanded), shape_broadcasted)
|
|
2077
|
+
shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - ops.rank(choicelist), 1, True) +
|
|
2073
2078
|
case_size_choice)
|
|
2074
|
-
choicelist = _broadcast_to_shape(
|
|
2079
|
+
choicelist = _broadcast_to_shape(ops.reshape(choicelist, shape_choice_expanded), shape_broadcasted)
|
|
2075
2080
|
|
|
2076
2081
|
slice_start = _list_comprehensions(ndim - 1, 0, True)
|
|
2077
2082
|
slice_size = (1,) + case_size
|
|
2078
|
-
dtype =
|
|
2083
|
+
dtype = ops.dtype(choicelist)
|
|
2079
2084
|
if isinstance(default, Tensor):
|
|
2080
|
-
default_slice = default.astype(
|
|
2085
|
+
default_slice = default.astype(ops.dtype(choicelist)).reshape(slice_size)
|
|
2081
2086
|
else:
|
|
2082
|
-
default_slice =
|
|
2087
|
+
default_slice = ops.fill(ops.dtype(choicelist), slice_size, default)
|
|
2083
2088
|
for i in range(case_num - 1, -1, -1):
|
|
2084
|
-
cond_slice =
|
|
2085
|
-
choice_slice =
|
|
2086
|
-
default_slice =
|
|
2087
|
-
return
|
|
2089
|
+
cond_slice = ops.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size)
|
|
2090
|
+
choice_slice = ops.tensor_slice(choicelist, (i,) + slice_start, slice_size)
|
|
2091
|
+
default_slice = ops.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice)
|
|
2092
|
+
return ops.reshape(default_slice, (case_size)).astype(dtype)
|
|
2088
2093
|
|
|
2089
2094
|
|
|
2090
2095
|
@_primexpr
|
|
@@ -2173,32 +2178,32 @@ def choose(a, choices, mode='clip'):
|
|
|
2173
2178
|
[ 10 -10 10]]
|
|
2174
2179
|
"""
|
|
2175
2180
|
a = _to_tensor(a)
|
|
2176
|
-
if not _check_is_int(
|
|
2181
|
+
if not _check_is_int(ops.dtype(a)):
|
|
2177
2182
|
_raise_value_error('`a` should be an int array')
|
|
2178
2183
|
if isinstance(choices, (tuple, list)):
|
|
2179
2184
|
# broadcasts choices to the same shape if choices is a sequence
|
|
2180
2185
|
choices = _to_tensor(*choices)
|
|
2181
2186
|
shapes = ()
|
|
2182
2187
|
for choice in choices:
|
|
2183
|
-
shapes += (
|
|
2184
|
-
shape_choice = _infer_out_shape(
|
|
2188
|
+
shapes += (ops.shape(choice),)
|
|
2189
|
+
shape_choice = _infer_out_shape(ops.shape(a), *shapes)
|
|
2185
2190
|
tmp = []
|
|
2186
2191
|
for choice in choices:
|
|
2187
2192
|
tmp.append(broadcast_to(choice, shape_choice))
|
|
2188
2193
|
choices = stack(tmp)
|
|
2189
2194
|
else:
|
|
2190
2195
|
choices = _to_tensor(choices)
|
|
2191
|
-
shape_choice = _infer_out_shape(
|
|
2192
|
-
choices =
|
|
2193
|
-
choices = broadcast_to(choices, (
|
|
2196
|
+
shape_choice = _infer_out_shape(ops.shape(a), ops.shape(choices)[1:])
|
|
2197
|
+
choices = ops.reshape(choices, choices.shape[:1] + _add_unit_axes(choices.shape[1:], len(shape_choice)))
|
|
2198
|
+
choices = broadcast_to(choices, (ops.shape(choices)[0],) + shape_choice)
|
|
2194
2199
|
|
|
2195
|
-
if
|
|
2200
|
+
if ops.rank(a) == 0 or ops.rank(choices) == 0:
|
|
2196
2201
|
_raise_value_error('input cannot be scalars')
|
|
2197
2202
|
a = broadcast_to(a, shape_choice)
|
|
2198
|
-
a = _check_indices(
|
|
2199
|
-
grid = _get_grid(
|
|
2200
|
-
indices = concatenate((a.reshape(
|
|
2201
|
-
return
|
|
2203
|
+
a = _check_indices(ops.shape(choices)[0], a, mode, allow_negative_index=False)
|
|
2204
|
+
grid = _get_grid(ops.shape(a))
|
|
2205
|
+
indices = concatenate((a.reshape(ops.shape(a) + (1,)), grid), -1)
|
|
2206
|
+
return ops.gather_nd(choices, indices)
|
|
2202
2207
|
|
|
2203
2208
|
|
|
2204
2209
|
def size(a, axis=None):
|
|
@@ -2312,24 +2317,24 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
|
|
|
2312
2317
|
[0 8 0]
|
|
2313
2318
|
[0 0 9]]]
|
|
2314
2319
|
"""
|
|
2315
|
-
ndim =
|
|
2316
|
-
shape =
|
|
2320
|
+
ndim = ops.rank(arr)
|
|
2321
|
+
shape = ops.shape(arr)
|
|
2317
2322
|
axis = _check_axis_in_range(axis, ndim)
|
|
2318
2323
|
arr = moveaxis(arr, axis, -1)
|
|
2319
|
-
arr =
|
|
2324
|
+
arr = ops.reshape(arr, (-1, ops.shape(arr)[-1]))
|
|
2320
2325
|
slices = []
|
|
2321
|
-
for i in range(
|
|
2326
|
+
for i in range(ops.shape(arr)[0]):
|
|
2322
2327
|
slices.append(func1d(arr[i], *args, **kwargs))
|
|
2323
2328
|
stacked_slices = stack(slices)
|
|
2324
2329
|
shape_stacked = (_tuple_slice(shape, None, axis) + _tuple_slice(shape, axis + 1, None) +
|
|
2325
|
-
_tuple_slice(
|
|
2326
|
-
res =
|
|
2330
|
+
_tuple_slice(ops.shape(stacked_slices), 1, None))
|
|
2331
|
+
res = ops.reshape(stacked_slices, shape_stacked)
|
|
2327
2332
|
|
|
2328
2333
|
# moves the dimensions returned by `func1d` back to `axis`
|
|
2329
|
-
ndim_func =
|
|
2334
|
+
ndim_func = ops.rank(res) - ndim + 1
|
|
2330
2335
|
if ndim_func >= 1:
|
|
2331
|
-
res = moveaxis(res,
|
|
2332
|
-
|
|
2336
|
+
res = moveaxis(res, ops.make_range(ndim - 1, ops.rank(res)),
|
|
2337
|
+
ops.make_range(axis, axis + ndim_func))
|
|
2333
2338
|
return res
|
|
2334
2339
|
|
|
2335
2340
|
|
|
@@ -2445,17 +2450,17 @@ def unravel_index(indices, shape, order='C'):
|
|
|
2445
2450
|
_raise_value_error('invalid order. Expected "C" or "F"')
|
|
2446
2451
|
if isinstance(shape, int):
|
|
2447
2452
|
shape = (shape,)
|
|
2448
|
-
ndim =
|
|
2453
|
+
ndim = ops.rank(indices)
|
|
2449
2454
|
if order == 'F':
|
|
2450
2455
|
sizes = _cumprod(shape)
|
|
2451
2456
|
else:
|
|
2452
2457
|
sizes = _cumprod(shape[::-1])
|
|
2453
2458
|
sizes = _to_tensor(sizes[::-1] + (1,))
|
|
2454
|
-
sizes =
|
|
2459
|
+
sizes = ops.reshape(sizes, (-1,) + _list_comprehensions(ndim, 1, True))
|
|
2455
2460
|
total_size = sizes[0]
|
|
2456
2461
|
indices = where(indices > total_size - 1, total_size - 1, indices)
|
|
2457
2462
|
if _get_device() == 'GPU':
|
|
2458
|
-
dtype =
|
|
2463
|
+
dtype = ops.dtype(total_size)
|
|
2459
2464
|
lowerbounds = (-(total_size.astype(mstype.float32))).astype(dtype)
|
|
2460
2465
|
else:
|
|
2461
2466
|
lowerbounds = -total_size
|
|
@@ -2515,7 +2520,7 @@ def apply_over_axes(func, a, axes):
|
|
|
2515
2520
|
res = a
|
|
2516
2521
|
for axis in axes:
|
|
2517
2522
|
res = func(res, axis=axis)
|
|
2518
|
-
res =
|
|
2523
|
+
res = ops.expand_dims(res, axis) if res.ndim != a.ndim else res
|
|
2519
2524
|
if res.ndim != a.ndim:
|
|
2520
2525
|
_raise_value_error("function is not returning a tensor of the correct shape")
|
|
2521
2526
|
return res
|
|
@@ -2546,7 +2551,7 @@ def argwhere(a):
|
|
|
2546
2551
|
Tensor(shape=[2, 3], dtype=Int64, value=[[0, 0, 0], [0, 1, 0]])
|
|
2547
2552
|
"""
|
|
2548
2553
|
a = _to_tensor(a)
|
|
2549
|
-
return
|
|
2554
|
+
return ops.argwhere(a)
|
|
2550
2555
|
|
|
2551
2556
|
|
|
2552
2557
|
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
|
@@ -2584,42 +2589,42 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
|
|
2584
2589
|
def unique_w_ind(arr):
|
|
2585
2590
|
array, sort_indices = arr.ravel().sort()
|
|
2586
2591
|
array_type = array.dtype
|
|
2587
|
-
cmp_array1 =
|
|
2588
|
-
cmp_array2 =
|
|
2592
|
+
cmp_array1 = ops.cat((array, Tensor([0], dtype=array_type)))
|
|
2593
|
+
cmp_array2 = ops.cat((Tensor([0], dtype=array_type), array))
|
|
2589
2594
|
mask = cmp_array1 != cmp_array2
|
|
2590
2595
|
mask[0] = True
|
|
2591
|
-
array =
|
|
2592
|
-
ind =
|
|
2596
|
+
array = ops.masked_select(array, mask[:-1])
|
|
2597
|
+
ind = ops.masked_select(sort_indices, mask[:-1])
|
|
2593
2598
|
return array, ind
|
|
2594
2599
|
|
|
2595
2600
|
if not isinstance(assume_unique, bool) or not isinstance(return_indices, bool):
|
|
2596
2601
|
_raise_type_error("assume_unique or return_indices is not bool type.")
|
|
2597
2602
|
ar1, ar2 = _to_tensor(ar1, ar2)
|
|
2598
|
-
ind1 =
|
|
2599
|
-
ind2 =
|
|
2603
|
+
ind1 = ops.fill(mstype.int32, (ar1.size,), -1)
|
|
2604
|
+
ind2 = ops.fill(mstype.int32, (ar2.size,), -1)
|
|
2600
2605
|
if not assume_unique:
|
|
2601
2606
|
if return_indices:
|
|
2602
2607
|
array1, ind1 = unique_w_ind(ar1)
|
|
2603
2608
|
array2, ind2 = unique_w_ind(ar2)
|
|
2604
2609
|
else:
|
|
2605
|
-
array1 =
|
|
2606
|
-
array2 =
|
|
2610
|
+
array1 = ops.unique(ar1)[0]
|
|
2611
|
+
array2 = ops.unique(ar2)[0]
|
|
2607
2612
|
else:
|
|
2608
2613
|
array1 = ar1.ravel()
|
|
2609
2614
|
array2 = ar2.ravel()
|
|
2610
2615
|
concat_array = concatenate((array1, array2))
|
|
2611
2616
|
if return_indices:
|
|
2612
|
-
concat_sort_indices =
|
|
2617
|
+
concat_sort_indices = ops.argsort(concat_array)
|
|
2613
2618
|
concat_array = concat_array[concat_sort_indices]
|
|
2614
2619
|
else:
|
|
2615
2620
|
concat_array, concat_sort_indices = concat_array.sort()
|
|
2616
2621
|
|
|
2617
2622
|
mask_res = concat_array[1:] == concat_array[:-1]
|
|
2618
|
-
res =
|
|
2623
|
+
res = ops.masked_select(concat_array[1:], mask_res)
|
|
2619
2624
|
|
|
2620
2625
|
if return_indices:
|
|
2621
|
-
ar1_indices =
|
|
2622
|
-
ar2_indices =
|
|
2626
|
+
ar1_indices = ops.masked_select(concat_sort_indices[:-1], mask_res)
|
|
2627
|
+
ar2_indices = ops.masked_select(concat_sort_indices[1:], mask_res)
|
|
2623
2628
|
if ar2_indices.shape[0] > 0:
|
|
2624
2629
|
ar2_indices = ar2_indices - array1.size
|
|
2625
2630
|
if not assume_unique:
|