mindspore 2.6.0__cp39-cp39-win_amd64.whl → 2.7.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +40 -9
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
- mindspore/_extends/parse/parser.py +36 -61
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +32 -13
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/amp.py +4 -4
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/common/__init__.py +27 -2
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +76 -15
- mindspore/common/api.py +193 -112
- mindspore/common/dtype.py +21 -11
- mindspore/common/dump.py +10 -15
- mindspore/common/generator.py +2 -3
- mindspore/common/hook_handle.py +11 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/parameter.py +26 -12
- mindspore/common/recompute.py +3 -3
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +48 -83
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +38 -23
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/datasets.py +20 -7
- mindspore/dataset/engine/datasets_user_defined.py +32 -2
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +7 -3
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -5
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +1 -0
- mindspore/include/api/cell.h +37 -1
- mindspore/include/api/delegate.h +10 -0
- mindspore/include/api/model.h +3 -0
- mindspore/include/api/types.h +2 -2
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +60 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +4 -44
- mindspore/mint/distributed/__init__.py +1 -0
- mindspore/mint/distributed/distributed.py +208 -5
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +164 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +122 -98
- mindspore/mint/nn/layer/normalization.py +8 -22
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/nn/cell.py +325 -499
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +32 -34
- mindspore/nn/layer/basic.py +67 -64
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +86 -85
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +37 -39
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +34 -37
- mindspore/nn/wrap/grad_reducer.py +37 -37
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +5 -5
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_vmap/vmap_array_ops.py +6 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +17 -8
- mindspore/ops/auto_generate/gen_extend_func.py +1 -51
- mindspore/ops/auto_generate/gen_ops_def.py +463 -257
- mindspore/ops/auto_generate/gen_ops_prim.py +1127 -885
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +8 -4
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +3 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +7 -94
- mindspore/ops/function/debug_func.py +4 -3
- mindspore/ops/function/grad/grad_func.py +1 -1
- mindspore/ops/function/math_func.py +21 -367
- mindspore/ops/function/nn_func.py +26 -41
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +31 -4
- mindspore/ops/functional.py +0 -2
- mindspore/ops/functional_overload.py +463 -6
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_custom_ops_utils.py +675 -8
- mindspore/ops/operations/_inner_ops.py +3 -6
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/comm_ops.py +185 -26
- mindspore/ops/operations/custom_ops.py +235 -172
- mindspore/ops/operations/debug_ops.py +55 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +15 -16
- mindspore/ops/operations/math_ops.py +3 -4
- mindspore/ops/operations/nn_ops.py +5 -6
- mindspore/ops/primitive.py +6 -10
- mindspore/ops/tensor_method.py +36 -4
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +7 -2
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +4 -2
- mindspore/parallel/_cell_wrapper.py +106 -40
- mindspore/parallel/_parallel_serialization.py +1 -1
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +13 -8
- mindspore/parallel/auto_parallel.py +12 -5
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +3 -1
- mindspore/parallel/cluster/process_entity/_api.py +84 -48
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +43 -4
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +1 -1
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
- mindspore/parallel/shard.py +2 -2
- mindspore/parallel/transform_safetensors.py +462 -174
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/profiler_context.py +25 -27
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_op_analyse.py +235 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +305 -314
- mindspore/profiler/envprofiler.py +12 -7
- mindspore/profiler/experimental_config.py +96 -6
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/npu_profiler.py +29 -19
- mindspore/profiler/profiler.py +35 -19
- mindspore/profiler/profiler_action_controller.py +64 -76
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +5 -5
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +60 -45
- mindspore/runtime/memory.py +21 -30
- mindspore/runtime/thread_bind_core.py +298 -164
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +6 -2
- mindspore/train/amp.py +43 -20
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_checkpoint.py +3 -6
- mindspore/train/callback/_flops_collector.py +1 -1
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +71 -13
- mindspore/train/data_sink.py +11 -2
- mindspore/train/dataset_helper.py +9 -0
- mindspore/train/model.py +51 -33
- mindspore/train/serialization.py +133 -111
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +162 -78
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +6 -9
- mindspore/version.py +1 -1
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +329 -367
- mindspore/_deprecated/jit.py +0 -198
- mindspore/experimental/es/__init__.py +0 -22
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,9 +13,9 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
|
-
from mindspore.common._stub_tensor import _convert_stub
|
|
17
16
|
from mindspore.ops._utils.arg_handler import *
|
|
18
17
|
from mindspore._c_expression import AdaptiveMaxPool2DPrim_
|
|
18
|
+
from mindspore._c_expression import ApplyRotaryPosEmbPrim_
|
|
19
19
|
from mindspore._c_expression import ArgMaxWithValuePrim_
|
|
20
20
|
from mindspore._c_expression import ArgMinWithValuePrim_
|
|
21
21
|
from mindspore._c_expression import BatchMatMulPrim_
|
|
@@ -27,6 +27,7 @@ from mindspore._c_expression import BroadcastToPrim_
|
|
|
27
27
|
from mindspore._c_expression import ConcatPrim_
|
|
28
28
|
from mindspore._c_expression import CrossPrim_
|
|
29
29
|
from mindspore._c_expression import CummaxPrim_
|
|
30
|
+
from mindspore._c_expression import DiagonalViewPrim_
|
|
30
31
|
from mindspore._c_expression import EluExtPrim_
|
|
31
32
|
from mindspore._c_expression import FFNExtPrim_
|
|
32
33
|
from mindspore._c_expression import FlashAttentionScoreGradPrim_
|
|
@@ -53,6 +54,7 @@ from mindspore._c_expression import NanToNumPrim_
|
|
|
53
54
|
from mindspore._c_expression import NLLLossGradPrim_
|
|
54
55
|
from mindspore._c_expression import NLLLossPrim_
|
|
55
56
|
from mindspore._c_expression import OneHotExtPrim_
|
|
57
|
+
from mindspore._c_expression import PagedAttentionPrim_
|
|
56
58
|
from mindspore._c_expression import PromptFlashAttentionPrim_
|
|
57
59
|
from mindspore._c_expression import ReduceAllPrim_
|
|
58
60
|
from mindspore._c_expression import ReduceAnyPrim_
|
|
@@ -91,6 +93,15 @@ class _PyboostAdaptiveMaxPool2DPrim(AdaptiveMaxPool2DPrim_):
|
|
|
91
93
|
adaptive_max_pool2d_impl = _PyboostAdaptiveMaxPool2DPrim()
|
|
92
94
|
|
|
93
95
|
|
|
96
|
+
class _PyboostApplyRotaryPosEmbPrim(ApplyRotaryPosEmbPrim_):
|
|
97
|
+
def __call__(self, query, key, cos, sin, position_ids, cos_format):
|
|
98
|
+
|
|
99
|
+
return super().__call__([query, key, cos, sin, position_ids, cos_format])
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
apply_rotary_pos_emb_impl = _PyboostApplyRotaryPosEmbPrim()
|
|
103
|
+
|
|
104
|
+
|
|
94
105
|
class _PyboostArgMaxWithValuePrim(ArgMaxWithValuePrim_):
|
|
95
106
|
def __call__(self, input, axis, keep_dims):
|
|
96
107
|
|
|
@@ -190,6 +201,15 @@ class _PyboostCummaxPrim(CummaxPrim_):
|
|
|
190
201
|
cummax_impl = _PyboostCummaxPrim()
|
|
191
202
|
|
|
192
203
|
|
|
204
|
+
class _PyboostDiagonalViewPrim(DiagonalViewPrim_):
|
|
205
|
+
def __call__(self, input, offset, dim1, dim2):
|
|
206
|
+
|
|
207
|
+
return super().__call__([input, offset, dim1, dim2])
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
diagonal_view_impl = _PyboostDiagonalViewPrim()
|
|
211
|
+
|
|
212
|
+
|
|
193
213
|
class _PyboostEluExtPrim(EluExtPrim_):
|
|
194
214
|
def __call__(self, input, alpha):
|
|
195
215
|
|
|
@@ -440,6 +460,16 @@ class _PyboostOneHotExtPrim(OneHotExtPrim_):
|
|
|
440
460
|
one_hot_ext_impl = _PyboostOneHotExtPrim()
|
|
441
461
|
|
|
442
462
|
|
|
463
|
+
class _PyboostPagedAttentionPrim(PagedAttentionPrim_):
|
|
464
|
+
def __call__(self, query, key_cache, value_cache, block_tables, context_lens, antiquant_scale, antiquant_offset, attn_mask, q_seq_lens, alibi_mask, head_num, scale_value, kv_head_num, kv_cache_quant_mode, mask_mode, mla_v_dim):
|
|
465
|
+
converted_kv_cache_quant_mode = str_to_enum('paged_attention', 'kv_cache_quant_mode', kv_cache_quant_mode)
|
|
466
|
+
converted_mask_mode = str_to_enum('paged_attention', 'mask_mode', mask_mode)
|
|
467
|
+
return super().__call__([query, key_cache, value_cache, block_tables, context_lens, antiquant_scale, antiquant_offset, attn_mask, q_seq_lens, alibi_mask, head_num, scale_value, kv_head_num, converted_kv_cache_quant_mode, converted_mask_mode, mla_v_dim])
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
paged_attention_impl = _PyboostPagedAttentionPrim()
|
|
471
|
+
|
|
472
|
+
|
|
443
473
|
class _PyboostPromptFlashAttentionPrim(PromptFlashAttentionPrim_):
|
|
444
474
|
def __call__(self, query, key, value, attn_mask, actual_seq_lengths, actual_seq_lengths_kv, pse_shift, deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads, sparse_mode, inner_precise):
|
|
445
475
|
converted_input_layout = str_to_enum('prompt_flash_attention', 'input_layout', input_layout)
|
|
@@ -25,6 +25,11 @@ from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, Mu
|
|
|
25
25
|
from mindspore.ops.composite.env_ops import env_get
|
|
26
26
|
from mindspore.ops.function.clip_func import clip_by_global_norm
|
|
27
27
|
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
|
28
|
+
from mindspore.ops.composite.multitype_ops.add_impl import augassign_add
|
|
29
|
+
from mindspore.ops.composite.multitype_ops.sub_impl import augassign_sub
|
|
30
|
+
from mindspore.ops.composite.multitype_ops.mul_impl import augassign_mul
|
|
31
|
+
from mindspore.ops.composite.multitype_ops.div_impl import augassign_div
|
|
32
|
+
from mindspore.ops.composite.multitype_ops.floordiv_impl import augassign_floordiv
|
|
28
33
|
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like, _ones_like_for_grad
|
|
29
34
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
30
35
|
from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
|
|
@@ -44,6 +49,11 @@ __all__ = [
|
|
|
44
49
|
'GradOperation',
|
|
45
50
|
'HyperMap',
|
|
46
51
|
'hyper_add',
|
|
52
|
+
'augassign_add',
|
|
53
|
+
'augassign_sub',
|
|
54
|
+
'augassign_mul',
|
|
55
|
+
'augassign_div',
|
|
56
|
+
'augassign_floordiv',
|
|
47
57
|
'zeros_like',
|
|
48
58
|
'ones_like',
|
|
49
59
|
'_ones_like_for_grad',
|
mindspore/ops/composite/base.py
CHANGED
|
@@ -412,14 +412,16 @@ class GradOperation(GradOperation_):
|
|
|
412
412
|
|
|
413
413
|
# check run exclude sens
|
|
414
414
|
if isinstance(fn, (FunctionType, MethodType)):
|
|
415
|
-
if not _pynative_executor.check_run(grad, fn, weights, None, *run_args
|
|
415
|
+
if not _pynative_executor.check_run(grad, fn, weights, None, *run_args,
|
|
416
|
+
create_graph=True):
|
|
416
417
|
_pynative_executor.set_grad_flag(True)
|
|
417
418
|
_pynative_executor.new_graph(fn, *args, **kwargs)
|
|
418
419
|
output = fn(*args, **kwargs)
|
|
419
420
|
_pynative_executor.end_graph(fn, output, *args, **kwargs)
|
|
420
421
|
else:
|
|
421
422
|
# Check if fn has run already
|
|
422
|
-
if not _pynative_executor.check_run(grad, fn, weights, None, *run_args
|
|
423
|
+
if not _pynative_executor.check_run(grad, fn, weights, None, *run_args,
|
|
424
|
+
create_graph=True):
|
|
423
425
|
requires_grad = fn.requires_grad
|
|
424
426
|
fn.requires_grad = True
|
|
425
427
|
fn(*args, **kwargs)
|
|
@@ -662,7 +664,8 @@ class _Grad(GradOperation_):
|
|
|
662
664
|
outputs = ()
|
|
663
665
|
run_forward = False
|
|
664
666
|
if isinstance(fn, (FunctionType, MethodType)):
|
|
665
|
-
if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args
|
|
667
|
+
if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args,
|
|
668
|
+
create_graph=True):
|
|
666
669
|
_pynative_executor.set_grad_flag(True)
|
|
667
670
|
_pynative_executor.new_graph(fn, *args, **kwargs)
|
|
668
671
|
outputs = fn(*args, **kwargs)
|
|
@@ -670,7 +673,8 @@ class _Grad(GradOperation_):
|
|
|
670
673
|
run_forward = True
|
|
671
674
|
else:
|
|
672
675
|
# Check if fn has run already.
|
|
673
|
-
if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args
|
|
676
|
+
if not _pynative_executor.check_run(grad, fn, weights, self.grad_position, *run_args,
|
|
677
|
+
create_graph=True):
|
|
674
678
|
requires_grad = fn.requires_grad
|
|
675
679
|
fn.requires_grad = True
|
|
676
680
|
outputs = fn(*args, **kwargs)
|
|
@@ -45,6 +45,12 @@ from mindspore.ops.composite.multitype_ops.uadd_impl import uadd
|
|
|
45
45
|
from mindspore.ops.composite.multitype_ops.in_impl import in_
|
|
46
46
|
from mindspore.ops.composite.multitype_ops.not_in_impl import not_in_
|
|
47
47
|
from mindspore.ops.composite.multitype_ops.invert_impl import invert
|
|
48
|
+
from mindspore.ops.composite.multitype_ops.add_impl import augassign_add
|
|
49
|
+
from mindspore.ops.composite.multitype_ops.sub_impl import augassign_sub
|
|
50
|
+
from mindspore.ops.composite.multitype_ops.mul_impl import augassign_mul
|
|
51
|
+
from mindspore.ops.composite.multitype_ops.div_impl import augassign_div
|
|
52
|
+
from mindspore.ops.composite.multitype_ops.floordiv_impl import augassign_floordiv
|
|
53
|
+
|
|
48
54
|
__all__ = [
|
|
49
55
|
'add',
|
|
50
56
|
'sub',
|
|
@@ -76,5 +82,10 @@ __all__ = [
|
|
|
76
82
|
'in_',
|
|
77
83
|
'not_in_',
|
|
78
84
|
'invert',
|
|
79
|
-
'_ones_like_for_grad'
|
|
85
|
+
'_ones_like_for_grad',
|
|
86
|
+
'augassign_add',
|
|
87
|
+
'augassign_sub',
|
|
88
|
+
'augassign_mul',
|
|
89
|
+
'augassign_div',
|
|
90
|
+
'augassign_floordiv'
|
|
80
91
|
]
|
|
@@ -34,9 +34,8 @@ from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
|
|
|
34
34
|
from mindspore import ops
|
|
35
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
36
|
from mindspore import _checkparam as validator
|
|
37
|
-
from mindspore.
|
|
38
|
-
|
|
39
|
-
index_op, inplace_index_put_op
|
|
37
|
+
from mindspore.ops.auto_generate.gen_ops_prim import select_ext_view_op, slice_ext_view_op, expand_dims_view_op, \
|
|
38
|
+
inplace_copy_op, index_op, inplace_index_put_op
|
|
40
39
|
|
|
41
40
|
slice_get_item = SliceGetItem()
|
|
42
41
|
hyper_map = base.HyperMap()
|
|
@@ -53,7 +52,7 @@ copy_with_slice = CopyWithSlice()
|
|
|
53
52
|
tensor_1d = Tensor([0], dtype=mstype.int64)
|
|
54
53
|
empty_tensor_1d = Tensor(shape=(0,), dtype=mstype.int64)
|
|
55
54
|
empty_tensor_9d = Tensor(shape=(0,)*9, dtype=mstype.int64)
|
|
56
|
-
|
|
55
|
+
EllipsisType = type(...)
|
|
57
56
|
|
|
58
57
|
|
|
59
58
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
@@ -109,7 +108,7 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
109
108
|
if transfer_type <= ValueTransferType.kScatterND:
|
|
110
109
|
data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
|
|
111
110
|
if transfer_type == ValueTransferType.kJustReturn:
|
|
112
|
-
return
|
|
111
|
+
return arg
|
|
113
112
|
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
114
113
|
return tensor_setitem_by_bool(data, new_index, value)
|
|
115
114
|
if transfer_type == ValueTransferType.kCopySlice:
|
|
@@ -230,131 +229,127 @@ setattr(tensor_operator_registry, "_tensor_getitem_origin", _tensor_getitem_orig
|
|
|
230
229
|
setattr(tensor_operator_registry, "_tensor_setitem_origin", _tensor_setitem_origin)
|
|
231
230
|
|
|
232
231
|
|
|
233
|
-
def _record_tensor_index(index, remain_indexes, dim):
|
|
234
|
-
"""Record indexes remained to be used by aclnnIndex/aclnnIndexPut"""
|
|
235
|
-
if len(remain_indexes) > dim:
|
|
236
|
-
remain_indexes[dim] = index
|
|
237
|
-
return remain_indexes
|
|
238
|
-
|
|
239
|
-
while dim > len(remain_indexes):
|
|
240
|
-
# use empty_tensor with dim_num 9 to indicate unused dim
|
|
241
|
-
remain_indexes.append(empty_tensor_9d)
|
|
242
|
-
|
|
243
|
-
remain_indexes.append(index)
|
|
244
|
-
return remain_indexes
|
|
245
|
-
|
|
246
|
-
|
|
247
232
|
def _count_indexed_dims(indexes):
|
|
248
233
|
"""Count indexed dims"""
|
|
249
234
|
count = 0
|
|
250
235
|
for index in indexes:
|
|
251
236
|
if isinstance(index, Tensor):
|
|
252
237
|
if index.dtype == mstype.bool_:
|
|
253
|
-
count += index
|
|
238
|
+
count += F.rank(index)
|
|
254
239
|
else:
|
|
255
240
|
count += 1
|
|
256
|
-
elif not isinstance(index, (type(None),
|
|
241
|
+
elif not isinstance(index, (type(None), EllipsisType, bool)):
|
|
257
242
|
count += 1
|
|
258
243
|
return count
|
|
259
244
|
|
|
260
245
|
|
|
261
|
-
def _do_select(self: Tensor, dim: int, index: int, dim_index: int,
|
|
246
|
+
def _do_select(self: Tensor, dim: int, index: int, dim_index: int, dim_size: int):
|
|
262
247
|
"""call select view operator"""
|
|
263
|
-
if not self_shape:
|
|
264
|
-
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
265
|
-
dim_size = self_shape[dim]
|
|
266
248
|
if index >= dim_size or index < -dim_size:
|
|
267
|
-
raise IndexError(
|
|
268
|
-
index = index + dim_size
|
|
249
|
+
raise IndexError("Index is out of bounds.")
|
|
250
|
+
index = (index + dim_size) % dim_size
|
|
269
251
|
return select_ext_view_op(self, dim, index)
|
|
270
252
|
|
|
271
253
|
|
|
272
|
-
def _do_slice(self: Tensor, dim: int, index: slice,
|
|
254
|
+
def _do_slice(self: Tensor, dim: int, index: slice, dim_size: int):
|
|
273
255
|
"""call slice view operator"""
|
|
274
256
|
def _get_index(index, default):
|
|
275
257
|
if index is None:
|
|
276
258
|
return default
|
|
277
259
|
if isinstance(index, Tensor):
|
|
278
|
-
return
|
|
260
|
+
return TensorToScalar()(index)
|
|
279
261
|
return index
|
|
280
262
|
|
|
281
|
-
if not self_shape:
|
|
282
|
-
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
283
263
|
step = _get_index(index.step, 1)
|
|
284
264
|
if step <= 0:
|
|
285
265
|
raise ValueError("slice step must be positive")
|
|
286
266
|
start = _get_index(index.start, 0)
|
|
287
|
-
end = _get_index(index.stop,
|
|
288
|
-
if start == 0 and end ==
|
|
267
|
+
end = _get_index(index.stop, dim_size)
|
|
268
|
+
if start == 0 and end == dim_size and step == 1:
|
|
289
269
|
return self
|
|
290
|
-
return
|
|
270
|
+
return slice_ext_view_op(self, dim, start, end, step)
|
|
291
271
|
|
|
292
272
|
|
|
293
273
|
def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexed_dims, dim_index, remain_indexes,
|
|
294
|
-
|
|
274
|
+
orig_dim, need_index_prim):
|
|
295
275
|
"""Process dim in multi dim index"""
|
|
276
|
+
result = prev_result
|
|
296
277
|
if isinstance(index, bool):
|
|
297
|
-
result =
|
|
278
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
298
279
|
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
299
|
-
|
|
300
|
-
|
|
280
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index_for_bool,)
|
|
281
|
+
need_index_prim = True
|
|
301
282
|
dim += 1
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
result = _do_slice(prev_result, dim, index, prev_shape)
|
|
309
|
-
# current dim in prev_shape will not be used later, ignore it
|
|
283
|
+
elif isinstance(index, int):
|
|
284
|
+
result = _do_select(prev_result, dim, index, dim_index, F.shape(orig_tensor)[orig_dim])
|
|
285
|
+
orig_dim += 1
|
|
286
|
+
elif isinstance(index, slice):
|
|
287
|
+
result = _do_slice(prev_result, dim, index, F.shape(orig_tensor)[orig_dim])
|
|
288
|
+
orig_dim += 1
|
|
310
289
|
dim += 1
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
result =
|
|
317
|
-
prev_shape.insert(dim, 1)
|
|
290
|
+
elif isinstance(index, EllipsisType):
|
|
291
|
+
ellipsis_dims = F.rank(orig_tensor) - indexed_dims
|
|
292
|
+
orig_dim += ellipsis_dims
|
|
293
|
+
dim += ellipsis_dims
|
|
294
|
+
elif index is None:
|
|
295
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
318
296
|
dim += 1
|
|
319
|
-
|
|
320
|
-
if isinstance(index, Tensor):
|
|
297
|
+
elif isinstance(index, Tensor):
|
|
321
298
|
result = prev_result
|
|
322
|
-
if index
|
|
299
|
+
if F.rank(index) == 0 and index.dtype in mstype.int_type + mstype.uint_type + (mstype.bool_,):
|
|
323
300
|
if index.dtype in mstype.int_type + mstype.uint_type:
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
301
|
+
index_py = TensorToScalar()(index)
|
|
302
|
+
result = _do_select(prev_result, dim, index_py, dim_index, F.shape(orig_tensor)[orig_dim])
|
|
303
|
+
orig_dim += 1
|
|
304
|
+
# in graph mode, remain_indexes in different branch requires same size, so we fill empty tensor to it
|
|
305
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes) + 1)
|
|
306
|
+
else:
|
|
307
|
+
# process index with Tensor bool type
|
|
308
|
+
result = expand_dims_view_op(prev_result, dim)
|
|
309
|
+
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
310
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + \
|
|
311
|
+
(index_for_bool,)
|
|
312
|
+
need_index_prim = True
|
|
313
|
+
dim += 1
|
|
314
|
+
else:
|
|
315
|
+
remain_indexes = remain_indexes[0:dim] + (empty_tensor_9d,) * (dim - len(remain_indexes)) + (index,)
|
|
316
|
+
need_index_prim = True
|
|
317
|
+
orig_dim += 1
|
|
332
318
|
dim += 1
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
return result, dim, remain_indexes, prev_shape
|
|
337
|
-
raise IndexError(f"Invalid tensor index type {index}")
|
|
319
|
+
else:
|
|
320
|
+
raise IndexError("Invalid tensor index type")
|
|
321
|
+
return result, dim, remain_indexes, orig_dim, need_index_prim
|
|
338
322
|
|
|
339
323
|
|
|
340
324
|
def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
|
|
341
325
|
"""Process indexes in tuple"""
|
|
342
326
|
self_viewed = self
|
|
343
|
-
self_viewed_shape = list(self.shape)
|
|
344
327
|
dim = 0
|
|
328
|
+
orig_dim = 0
|
|
329
|
+
need_index_prim = False
|
|
345
330
|
for i, index in enumerate(indexes):
|
|
346
331
|
if isinstance(index, (list, tuple, np.ndarray)):
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
332
|
+
if not F.isconstant(index):
|
|
333
|
+
raise IndexError(
|
|
334
|
+
"Current Tensor indexing does not support mutable list/tuple or list containing tensors. "
|
|
335
|
+
"Please use an immutable expression instead.")
|
|
336
|
+
index = Tensor(index)
|
|
337
|
+
if isinstance(index, Tensor) and \
|
|
338
|
+
F.dtype(index) in (mstype.int8, mstype.int16, mstype.uint16, mstype.uint32,
|
|
339
|
+
mstype.uint64, mstype.float16, mstype.float32, mstype.float64):
|
|
340
|
+
# only uint8, int32 and int64 are supported by IndexOp
|
|
341
|
+
index = F.cast(index, mstype.int64)
|
|
342
|
+
self_viewed, dim, remain_indexes, orig_dim, need_index_prim = _process_dim_in_multi_dim_index(
|
|
343
|
+
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, orig_dim, need_index_prim)
|
|
344
|
+
return self_viewed, remain_indexes, need_index_prim
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _check_type_of_list_index(index_list):
|
|
348
|
+
"""Check type of element in list index"""
|
|
349
|
+
for index in index_list:
|
|
350
|
+
if isinstance(index, (Tensor, list, tuple, slice, type(None), EllipsisType)):
|
|
351
|
+
return True
|
|
352
|
+
return False
|
|
358
353
|
|
|
359
354
|
|
|
360
355
|
def _wrap_index_to_tuple(index):
|
|
@@ -362,7 +357,7 @@ def _wrap_index_to_tuple(index):
|
|
|
362
357
|
if isinstance(index, tuple):
|
|
363
358
|
return index
|
|
364
359
|
if isinstance(index, list):
|
|
365
|
-
if len(index) < 32 and
|
|
360
|
+
if len(index) < 32 and _check_type_of_list_index(index):
|
|
366
361
|
return tuple(index)
|
|
367
362
|
return (index,)
|
|
368
363
|
|
|
@@ -370,64 +365,92 @@ def _wrap_index_to_tuple(index):
|
|
|
370
365
|
def _tensor_getitem(self, index):
|
|
371
366
|
"""Handle tensor getitem"""
|
|
372
367
|
if isinstance(index, bool):
|
|
373
|
-
self_viewed =
|
|
368
|
+
self_viewed = expand_dims_view_op(self, 0)
|
|
374
369
|
index_for_bool = tensor_1d if index else empty_tensor_1d
|
|
375
370
|
return index_op(self_viewed, [index_for_bool])
|
|
376
371
|
if isinstance(index, int):
|
|
377
|
-
|
|
372
|
+
self_shape = F.shape(self)
|
|
373
|
+
if not self_shape:
|
|
374
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
375
|
+
return _do_select(self, 0, index, 0, self_shape[0])
|
|
378
376
|
if isinstance(index, slice):
|
|
379
|
-
|
|
377
|
+
self_shape = F.shape(self)
|
|
378
|
+
if not self_shape:
|
|
379
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
380
|
+
result = _do_slice(self, 0, index, self_shape[0])
|
|
380
381
|
return result
|
|
381
382
|
if index is None:
|
|
382
|
-
return
|
|
383
|
-
if isinstance(index,
|
|
383
|
+
return expand_dims_view_op(self, 0)
|
|
384
|
+
if isinstance(index, EllipsisType):
|
|
384
385
|
return self
|
|
385
386
|
indexes = _wrap_index_to_tuple(index)
|
|
386
387
|
indexed_dims = _count_indexed_dims(indexes)
|
|
387
|
-
if self
|
|
388
|
-
raise IndexError(
|
|
389
|
-
remain_indexes =
|
|
390
|
-
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
391
|
-
if not
|
|
388
|
+
if F.rank(self) < indexed_dims:
|
|
389
|
+
raise IndexError("For getitem, there are too many indices.")
|
|
390
|
+
remain_indexes = ()
|
|
391
|
+
self_viewed, remain_indexes, need_index_prim = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
392
|
+
if not need_index_prim:
|
|
392
393
|
return self_viewed
|
|
393
394
|
return index_op(self_viewed, remain_indexes)
|
|
394
395
|
|
|
395
396
|
|
|
397
|
+
def do_copy(dst, src):
|
|
398
|
+
"""do copy"""
|
|
399
|
+
src_shape = F.shape(src)
|
|
400
|
+
dst_shape = F.shape(dst)
|
|
401
|
+
if F.is_sequence_value_unknown(src_shape) or F.is_sequence_value_unknown(dst_shape):
|
|
402
|
+
return inplace_copy_op(dst, src)
|
|
403
|
+
if src_shape == dst_shape or not src_shape:
|
|
404
|
+
return inplace_copy_op(dst, src)
|
|
405
|
+
# remove all leading 1, e.g. (1, 1, 2, 3) -> (2, 3)
|
|
406
|
+
idx = 0
|
|
407
|
+
while idx < len(src_shape) and src_shape[idx] == 1:
|
|
408
|
+
idx += 1
|
|
409
|
+
src_viewed = src.view(src_shape[idx:])
|
|
410
|
+
return inplace_copy_op(dst, src_viewed)
|
|
411
|
+
|
|
412
|
+
|
|
396
413
|
def _tensor_setitem(self, index, value):
|
|
397
414
|
"""Handle tensor setitem"""
|
|
398
415
|
if not isinstance(value, Tensor):
|
|
399
416
|
if isinstance(value, (bool, int, float)):
|
|
400
417
|
value = Tensor(value, dtype=self.dtype)
|
|
401
418
|
else:
|
|
402
|
-
raise TypeError(
|
|
419
|
+
raise TypeError("For __setitem__, the type of value can only be bool, int, float or Tensor.")
|
|
403
420
|
|
|
404
421
|
if isinstance(index, bool) and index is False:
|
|
405
422
|
return self
|
|
406
|
-
if isinstance(index,
|
|
407
|
-
|
|
423
|
+
if isinstance(index, EllipsisType):
|
|
424
|
+
do_copy(self, value)
|
|
408
425
|
return self
|
|
409
426
|
if index is None or (isinstance(index, bool) and index is True):
|
|
410
|
-
self_viewed =
|
|
411
|
-
|
|
427
|
+
self_viewed = expand_dims_view_op(self, 0)
|
|
428
|
+
do_copy(self_viewed, value)
|
|
412
429
|
return self
|
|
413
430
|
if isinstance(index, int):
|
|
414
|
-
|
|
415
|
-
|
|
431
|
+
self_shape = F.shape(self)
|
|
432
|
+
if not self_shape:
|
|
433
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
434
|
+
self_viewed = _do_select(self, 0, index, 0, self_shape[0])
|
|
435
|
+
do_copy(self_viewed, value)
|
|
416
436
|
return self
|
|
417
437
|
if isinstance(index, slice):
|
|
418
|
-
|
|
419
|
-
|
|
438
|
+
self_shape = F.shape(self)
|
|
439
|
+
if not self_shape:
|
|
440
|
+
raise TypeError("Invalid index of a 0-dim tensor.")
|
|
441
|
+
self_viewed = _do_slice(self, 0, index, self_shape[0])
|
|
442
|
+
do_copy(self_viewed, value)
|
|
420
443
|
return self
|
|
421
444
|
indexes = _wrap_index_to_tuple(index)
|
|
422
445
|
indexed_dims = _count_indexed_dims(indexes)
|
|
423
|
-
if self
|
|
424
|
-
raise IndexError(
|
|
425
|
-
remain_indexes =
|
|
426
|
-
self_viewed, remain_indexes = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
427
|
-
if not
|
|
428
|
-
|
|
446
|
+
if F.rank(self) < indexed_dims:
|
|
447
|
+
raise IndexError("For setitem, there are too many indices")
|
|
448
|
+
remain_indexes = ()
|
|
449
|
+
self_viewed, remain_indexes, need_index_prim = _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims)
|
|
450
|
+
if not need_index_prim:
|
|
451
|
+
do_copy(self_viewed, value)
|
|
429
452
|
return self
|
|
430
|
-
inplace_index_put_op(self_viewed, remain_indexes, value)
|
|
453
|
+
inplace_index_put_op(self_viewed, remain_indexes, value, False)
|
|
431
454
|
return self
|
|
432
455
|
|
|
433
456
|
|
|
@@ -829,6 +852,7 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
829
852
|
end_mask += 2 ** i
|
|
830
853
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
831
854
|
|
|
855
|
+
|
|
832
856
|
def _check_dim_shape_valid(data, tensor_index):
|
|
833
857
|
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
834
858
|
if data.ndim < tensor_index.ndim:
|