mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
"""constexpr util"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
from enum import IntEnum
|
|
19
|
+
import numpy as np
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
21
22
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
22
23
|
from mindspore.ops import functional as F
|
|
23
24
|
from mindspore.ops import operations as P
|
|
@@ -35,6 +36,8 @@ from mindspore import ops
|
|
|
35
36
|
from mindspore.ops.primitive import _primexpr
|
|
36
37
|
from mindspore import _checkparam as validator
|
|
37
38
|
from mindspore.common._stub_tensor import _convert_stub
|
|
39
|
+
from mindspore.ops.auto_generate.gen_ops_prim import select_ext_op, slice_ext_op, inplace_copy_op, \
|
|
40
|
+
index_op, inplace_index_put_op
|
|
38
41
|
|
|
39
42
|
slice_get_item = SliceGetItem()
|
|
40
43
|
hyper_map = base.HyperMap()
|
|
@@ -45,9 +48,15 @@ is_parameter = IsParameter()
|
|
|
45
48
|
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
46
49
|
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
47
50
|
|
|
48
|
-
|
|
51
|
+
select_view = SelectView()
|
|
49
52
|
copy_with_slice = CopyWithSlice()
|
|
50
53
|
|
|
54
|
+
tensor_1d = Tensor([0], dtype=mstype.int64)
|
|
55
|
+
empty_tensor_1d = Tensor_(shape=(0,), dtype=mstype.int64)
|
|
56
|
+
empty_tensor_9d = Tensor_(shape=(0,)*9, dtype=mstype.int64)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
51
60
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
52
61
|
new_axis_mask=0, shrink_axis_mask=0):
|
|
53
62
|
"""strided_slice primitive cache"""
|
|
@@ -148,7 +157,7 @@ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=N
|
|
|
148
157
|
elif transfer_type == ValueTransferType.kSelect:
|
|
149
158
|
data = F.select(Tensor(new_index), value, data)
|
|
150
159
|
elif transfer_type == ValueTransferType.kSelectView:
|
|
151
|
-
data =
|
|
160
|
+
data = select_view(data, arg[0], arg[1])
|
|
152
161
|
elif transfer_type == ValueTransferType.kCopyView:
|
|
153
162
|
value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
|
|
154
163
|
data = copy_with_slice(data, value)
|
|
@@ -196,14 +205,14 @@ def value_update(transfer_types, args, data, value):
|
|
|
196
205
|
return value
|
|
197
206
|
|
|
198
207
|
|
|
199
|
-
def
|
|
208
|
+
def _tensor_getitem_origin(self, index):
|
|
200
209
|
"""Handle tensor getitem"""
|
|
201
210
|
new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
|
|
202
211
|
self, index)
|
|
203
212
|
return data_update(tensor_update_types, tensor_update_args, self, new_index)
|
|
204
213
|
|
|
205
214
|
|
|
206
|
-
def
|
|
215
|
+
def _tensor_setitem_origin(self, index, value):
|
|
207
216
|
"""Handle tensor setitem"""
|
|
208
217
|
setitem_info = setitem_tensor_index_info(self, index, value)
|
|
209
218
|
new_index = setitem_info[0]
|
|
@@ -218,8 +227,213 @@ def _tensor_setitem(self, index, value):
|
|
|
218
227
|
return output
|
|
219
228
|
|
|
220
229
|
|
|
221
|
-
setattr(tensor_operator_registry, "
|
|
222
|
-
setattr(tensor_operator_registry, "
|
|
230
|
+
setattr(tensor_operator_registry, "_tensor_getitem_origin", _tensor_getitem_origin)
|
|
231
|
+
setattr(tensor_operator_registry, "_tensor_setitem_origin", _tensor_setitem_origin)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _record_tensor_index(index, remain_indexes, dim):
|
|
235
|
+
"""Record indexes remained to be used by aclnnIndex/aclnnIndexPut"""
|
|
236
|
+
if len(remain_indexes) > dim:
|
|
237
|
+
remain_indexes[dim] = index
|
|
238
|
+
return remain_indexes
|
|
239
|
+
|
|
240
|
+
while dim > len(remain_indexes):
|
|
241
|
+
# use empty_tensor with dim_num 9 to indicate unused dim
|
|
242
|
+
remain_indexes.append(empty_tensor_9d)
|
|
243
|
+
|
|
244
|
+
remain_indexes.append(index)
|
|
245
|
+
return remain_indexes
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _count_indexed_dims(indexes):
|
|
249
|
+
"""Count indexed dims"""
|
|
250
|
+
count = 0
|
|
251
|
+
for index in indexes:
|
|
252
|
+
if isinstance(index, Tensor):
|
|
253
|
+
if index.dtype == mstype.bool_:
|
|
254
|
+
count += index.ndim
|
|
255
|
+
else:
|
|
256
|
+
count += 1
|
|
257
|
+
elif not isinstance(index, (type(None), type(...), bool)):
|
|
258
|
+
count += 1
|
|
259
|
+
return count
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _do_select(self: Tensor, dim: int, index: int, dim_index: int, self_shape: list):
|
|
263
|
+
"""call select view operator"""
|
|
264
|
+
if not self_shape:
|
|
265
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
266
|
+
dim_size = self_shape[dim]
|
|
267
|
+
if index >= dim_size or index < -dim_size:
|
|
268
|
+
raise IndexError(f"Index {index} is out of bounds for dimension {dim_index} with size {dim_size}")
|
|
269
|
+
index = index + dim_size if index < 0 else index
|
|
270
|
+
return select_ext_op(self, dim, index)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _do_slice(self: Tensor, dim: int, index: slice, self_shape: list):
|
|
274
|
+
"""call slice view operator"""
|
|
275
|
+
def _get_index(index, default):
|
|
276
|
+
if index is None:
|
|
277
|
+
return default
|
|
278
|
+
if isinstance(index, Tensor):
|
|
279
|
+
return index.__index__()
|
|
280
|
+
return index
|
|
281
|
+
|
|
282
|
+
if not self_shape:
|
|
283
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
284
|
+
step = _get_index(index.step, 1)
|
|
285
|
+
if step <= 0:
|
|
286
|
+
raise ValueError("slice step must be positive")
|
|
287
|
+
start = _get_index(index.start, 0)
|
|
288
|
+
end = _get_index(index.stop, self_shape[dim])
|
|
289
|
+
if start == 0 and end == self_shape[dim] and step == 1:
|
|
290
|
+
return self
|
|
291
|
+
return slice_ext_op(self, dim, start, end, step)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexed_dims, dim_index, remain_indexes,
|
|
295
|
+
prev_shape):
|
|
296
|
+
"""Process dim in multi dim index"""
|
|
297
|
+
if isinstance(index, bool):
|
|
298
|
+
result = F.expand_dims(prev_result, dim)
|
|
299
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
300
|
+
_record_tensor_index(index_for_bool, remain_indexes, dim)
|
|
301
|
+
prev_shape.insert(dim, 1)
|
|
302
|
+
dim += 1
|
|
303
|
+
return result, dim, remain_indexes, prev_shape
|
|
304
|
+
if isinstance(index, int):
|
|
305
|
+
result = _do_select(prev_result, dim, index, dim_index, prev_shape)
|
|
306
|
+
del prev_shape[dim]
|
|
307
|
+
return result, dim, remain_indexes, prev_shape
|
|
308
|
+
if isinstance(index, slice):
|
|
309
|
+
result = _do_slice(prev_result, dim, index, prev_shape)
|
|
310
|
+
# current dim in prev_shape will not be used later, ignore it
|
|
311
|
+
dim += 1
|
|
312
|
+
return result, dim, remain_indexes, prev_shape
|
|
313
|
+
if isinstance(index, type(...)):
|
|
314
|
+
dim += (orig_tensor.ndim - indexed_dims)
|
|
315
|
+
return prev_result, dim, remain_indexes, prev_shape
|
|
316
|
+
if index is None:
|
|
317
|
+
result = F.expand_dims(prev_result, dim)
|
|
318
|
+
prev_shape.insert(dim, 1)
|
|
319
|
+
dim += 1
|
|
320
|
+
return result, dim, remain_indexes, prev_shape
|
|
321
|
+
if isinstance(index, Tensor):
|
|
322
|
+
result = prev_result
|
|
323
|
+
if index.ndim == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
324
|
+
if index.dtype in mstype.int_type + mstype.uint_type:
|
|
325
|
+
result = _do_select(prev_result, dim, index.item(), dim_index, prev_shape)
|
|
326
|
+
del prev_shape[dim]
|
|
327
|
+
return result, dim, remain_indexes, prev_shape
|
|
328
|
+
# process index with Tensor bool type
|
|
329
|
+
result = F.expand_dims(prev_result, dim)
|
|
330
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
331
|
+
_record_tensor_index(index_for_bool, remain_indexes, dim)
|
|
332
|
+
prev_shape.insert(dim, 1)
|
|
333
|
+
dim += 1
|
|
334
|
+
return result, dim, remain_indexes, prev_shape
|
|
335
|
+
_record_tensor_index(index, remain_indexes, dim)
|
|
336
|
+
dim += 1
|
|
337
|
+
return result, dim, remain_indexes, prev_shape
|
|
338
|
+
raise IndexError(f"Invalid tensor index type {index}")
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
|
|
342
|
+
"""Process indexes in tuple"""
|
|
343
|
+
self_viewed = self
|
|
344
|
+
self_viewed_shape = list(self.shape)
|
|
345
|
+
dim = 0
|
|
346
|
+
for i, index in enumerate(indexes):
|
|
347
|
+
if isinstance(index, (list, tuple, np.ndarray)):
|
|
348
|
+
index_np = np.array(index) if isinstance(index, (list, tuple)) else index
|
|
349
|
+
if index_np.dtype in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
|
350
|
+
np.float16, np.float32, np.float64):
|
|
351
|
+
index = Tensor(index_np, mstype.int64)
|
|
352
|
+
elif index_np.dtype == np.bool_:
|
|
353
|
+
index = Tensor(index_np, mstype.bool_)
|
|
354
|
+
else:
|
|
355
|
+
raise TypeError(f"Index {index} contain unsupported elements")
|
|
356
|
+
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
|
|
357
|
+
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
|
|
358
|
+
return self_viewed, remain_indexes
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _wrap_index_to_tuple(index):
|
|
362
|
+
"""Wrap index to tuple"""
|
|
363
|
+
if isinstance(index, tuple):
|
|
364
|
+
return index
|
|
365
|
+
if isinstance(index, list):
|
|
366
|
+
if len(index) < 32 and any(isinstance(i, (Tensor, list, tuple, slice, type(None), type(...))) for i in index):
|
|
367
|
+
return tuple(index)
|
|
368
|
+
return (index,)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _tensor_getitem(self, index):
|
|
372
|
+
"""Handle tensor getitem"""
|
|
373
|
+
if isinstance(index, bool):
|
|
374
|
+
self_viewed = F.expand_dims(self, 0)
|
|
375
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
376
|
+
return index_op(self_viewed, [index_for_bool])
|
|
377
|
+
if isinstance(index, int):
|
|
378
|
+
return _do_select(self, 0, index, 0, list(self.shape))
|
|
379
|
+
if isinstance(index, slice):
|
|
380
|
+
result = _do_slice(self, 0, index, list(self.shape))
|
|
381
|
+
return result
|
|
382
|
+
if index is None:
|
|
383
|
+
return F.expand_dims(self, 0)
|
|
384
|
+
if isinstance(index, type(...)):
|
|
385
|
+
return self
|
|
386
|
+
indexes = _wrap_index_to_tuple(index)
|
|
387
|
+
indexed_dims = _count_indexed_dims(indexes)
|
|
388
|
+
if self.ndim < indexed_dims:
|
|
389
|
+
raise IndexError(f"too many indices for tensor with dimension size {self.ndim}")
|
|
390
|
+
remain_indexes = []
|
|
391
|
+
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
392
|
+
if not remain_indexes:
|
|
393
|
+
return self_viewed
|
|
394
|
+
return index_op(self_viewed, remain_indexes)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _tensor_setitem(self, index, value):
|
|
398
|
+
"""Handle tensor setitem"""
|
|
399
|
+
if not isinstance(value, Tensor):
|
|
400
|
+
if isinstance(value, (bool, int, float)):
|
|
401
|
+
value = Tensor(value, dtype=self.dtype)
|
|
402
|
+
else:
|
|
403
|
+
raise TypeError(f"Can't assign a {type(value)} to a {self.dtype}.")
|
|
404
|
+
|
|
405
|
+
if isinstance(index, bool) and index is False:
|
|
406
|
+
return self
|
|
407
|
+
if isinstance(index, type(...)):
|
|
408
|
+
inplace_copy_op(self, value)
|
|
409
|
+
return self
|
|
410
|
+
if index is None or (isinstance(index, bool) and index is True):
|
|
411
|
+
self_viewed = F.expand_dims(self, 0)
|
|
412
|
+
inplace_copy_op(self_viewed, value)
|
|
413
|
+
return self
|
|
414
|
+
if isinstance(index, int):
|
|
415
|
+
self_viewed = _do_select(self, 0, index, 0, list(self.shape))
|
|
416
|
+
inplace_copy_op(self_viewed, value)
|
|
417
|
+
return self
|
|
418
|
+
if isinstance(index, slice):
|
|
419
|
+
self_viewed = _do_slice(self, 0, index, list(self.shape))
|
|
420
|
+
inplace_copy_op(self_viewed, value)
|
|
421
|
+
return self
|
|
422
|
+
indexes = _wrap_index_to_tuple(index)
|
|
423
|
+
indexed_dims = _count_indexed_dims(indexes)
|
|
424
|
+
if self.ndim < indexed_dims:
|
|
425
|
+
raise IndexError(f"too many indices for tensor with dimension size {self.ndim}")
|
|
426
|
+
remain_indexes = []
|
|
427
|
+
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
428
|
+
if not remain_indexes:
|
|
429
|
+
inplace_copy_op(self_viewed, value)
|
|
430
|
+
return self
|
|
431
|
+
inplace_index_put_op(self_viewed, remain_indexes, value)
|
|
432
|
+
return self
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
setattr(tensor_operator_registry, "_tensor_getitem", _tensor_getitem)
|
|
436
|
+
setattr(tensor_operator_registry, "_tensor_setitem", _tensor_setitem)
|
|
223
437
|
|
|
224
438
|
|
|
225
439
|
def _tensor_add(self, other):
|
|
@@ -313,31 +527,16 @@ def _check_scalar_tensor_args(args):
|
|
|
313
527
|
const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
314
528
|
|
|
315
529
|
|
|
316
|
-
def tensor_item(data
|
|
317
|
-
"""Tensor getitem
|
|
318
|
-
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
530
|
+
def tensor_item(data):
|
|
531
|
+
"""Tensor getitem which has only one element."""
|
|
319
532
|
if data.ndim == 0:
|
|
320
|
-
_check_scalar_tensor_args(args)
|
|
321
533
|
return TensorToScalar()(data)
|
|
322
|
-
if
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
return TensorToScalar()(data[0])
|
|
329
|
-
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
330
|
-
|
|
331
|
-
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
332
|
-
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
333
|
-
|
|
334
|
-
if len(args) == data.ndim:
|
|
335
|
-
return tensor_index_by_tuple(data, args)
|
|
336
|
-
if len(args) > 1:
|
|
337
|
-
const_utils.raise_value_error("Incorrect number of indices for array")
|
|
338
|
-
output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
|
|
339
|
-
return TensorToScalar()(output)
|
|
340
|
-
|
|
534
|
+
if data.shape == (1,):
|
|
535
|
+
return TensorToScalar()(data[0])
|
|
536
|
+
exp_msg = const_utils.gen_exception_msg("The tensor should have only one element. "
|
|
537
|
+
"But the shape of input tensor is {}.", data.shape)
|
|
538
|
+
const_utils.raise_value_error(exp_msg)
|
|
539
|
+
return None
|
|
341
540
|
|
|
342
541
|
def tensor_itemset(data, *args):
|
|
343
542
|
"""Tensor setitem by index and value."""
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
"""Implementation for internal polymorphism `pow` operations."""
|
|
17
17
|
|
|
18
18
|
from __future__ import absolute_import
|
|
19
|
-
from mindspore.ops.composite.multitype_ops import _compile_utils as utils
|
|
20
19
|
from mindspore.ops.composite import base
|
|
21
20
|
from mindspore.ops import functional as F
|
|
22
21
|
|
|
@@ -53,34 +52,6 @@ def _scalar_pow_tensor(x, y):
|
|
|
53
52
|
return F.tensor_pow(x, y)
|
|
54
53
|
|
|
55
54
|
|
|
56
|
-
@pow_.register("Tuple", "Tensor")
|
|
57
|
-
def _tuple_pow_tensor(x, y):
|
|
58
|
-
"""Returns x ** y where x is a tuple and y is a tensor. """
|
|
59
|
-
x = utils.sequence_to_tensor(x, y.dtype)
|
|
60
|
-
return F.tensor_pow(x, y)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
@pow_.register("Tensor", "Tuple")
|
|
64
|
-
def _tensor_pow_tuple(x, y):
|
|
65
|
-
"""Returns x ** y where x is a tensor and y is a tuple. """
|
|
66
|
-
y = utils.sequence_to_tensor(y, x.dtype)
|
|
67
|
-
return F.tensor_pow(x, y)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
@pow_.register("List", "Tensor")
|
|
71
|
-
def _list_pow_tensor(x, y):
|
|
72
|
-
"""Returns x ** y where x is a list and y is a tensor. """
|
|
73
|
-
x = utils.sequence_to_tensor(x, y.dtype)
|
|
74
|
-
return F.tensor_pow(x, y)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
@pow_.register("Tensor", "List")
|
|
78
|
-
def _tensor_pow_list(x, y):
|
|
79
|
-
"""Returns x ** y where x is a tensor and y is a list. """
|
|
80
|
-
y = utils.sequence_to_tensor(y, x.dtype)
|
|
81
|
-
return F.tensor_pow(x, y)
|
|
82
|
-
|
|
83
|
-
|
|
84
55
|
@pow_.register_default()
|
|
85
56
|
def default_pow(x, y):
|
|
86
57
|
"""Default function for pow."""
|
|
@@ -167,6 +167,9 @@ from .array_func import (
|
|
|
167
167
|
top_k,
|
|
168
168
|
deepcopy,
|
|
169
169
|
arange_ext,
|
|
170
|
+
view_as,
|
|
171
|
+
type_as,
|
|
172
|
+
expand_as,
|
|
170
173
|
)
|
|
171
174
|
from .parameter_func import (
|
|
172
175
|
assign,
|
|
@@ -317,6 +320,7 @@ from .math_func import (
|
|
|
317
320
|
sinh,
|
|
318
321
|
cosh,
|
|
319
322
|
tanh,
|
|
323
|
+
tanh_,
|
|
320
324
|
tanhshrink,
|
|
321
325
|
asinh,
|
|
322
326
|
arcsinh,
|
|
@@ -379,6 +383,7 @@ from .math_func import (
|
|
|
379
383
|
atleast_1d,
|
|
380
384
|
dstack,
|
|
381
385
|
diff,
|
|
386
|
+
diff_ext,
|
|
382
387
|
atleast_2d,
|
|
383
388
|
cartesian_prod,
|
|
384
389
|
atleast_3d,
|
|
@@ -400,6 +405,7 @@ from .math_func import (
|
|
|
400
405
|
remainder,
|
|
401
406
|
remainder_ext,
|
|
402
407
|
iou,
|
|
408
|
+
rotated_iou,
|
|
403
409
|
bmm,
|
|
404
410
|
trapz,
|
|
405
411
|
cholesky,
|
|
@@ -443,6 +449,7 @@ from .math_func import (
|
|
|
443
449
|
tensor_dot,
|
|
444
450
|
vecdot,
|
|
445
451
|
dot,
|
|
452
|
+
isnan_ext,
|
|
446
453
|
batch_dot,
|
|
447
454
|
eps,
|
|
448
455
|
)
|
|
@@ -460,6 +467,7 @@ from .nn_func import (
|
|
|
460
467
|
max_pool2d,
|
|
461
468
|
max_pool3d,
|
|
462
469
|
batch_norm,
|
|
470
|
+
add_rms_norm,
|
|
463
471
|
rms_norm,
|
|
464
472
|
bidense,
|
|
465
473
|
celu,
|
|
@@ -485,6 +493,8 @@ from .nn_func import (
|
|
|
485
493
|
soft_shrink,
|
|
486
494
|
is_floating_point,
|
|
487
495
|
incre_flash_attention,
|
|
496
|
+
prompt_flash_attention,
|
|
497
|
+
flash_attention_score,
|
|
488
498
|
intopk,
|
|
489
499
|
interpolate,
|
|
490
500
|
upsample,
|
|
@@ -531,6 +541,7 @@ from .nn_func import (
|
|
|
531
541
|
sigmoid,
|
|
532
542
|
logsigmoid,
|
|
533
543
|
relu,
|
|
544
|
+
relu_,
|
|
534
545
|
relu6,
|
|
535
546
|
rrelu,
|
|
536
547
|
conv3d,
|
|
@@ -723,6 +734,7 @@ from .clip_func import (
|
|
|
723
734
|
clip_by_value,
|
|
724
735
|
clip_by_norm,
|
|
725
736
|
clamp,
|
|
737
|
+
clamp_,
|
|
726
738
|
clip,
|
|
727
739
|
clip_by_global_norm,
|
|
728
740
|
)
|