mindspore-2.4.1-cp311-cp311-win_amd64.whl → mindspore-2.5.0-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +391 -265
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
mindspore/numpy/utils.py
CHANGED
@@ -19,8 +19,8 @@ import types
 
 from mindspore.common import Tensor
 from mindspore._c_expression import Tensor as Tensor_
-from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
+from mindspore import ops
 
 from mindspore.numpy.utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
     _tuple_setitem, _callable_const, _check_is_float, _get_device
@@ -65,7 +65,7 @@ def _check_input_for_asarray(array_like):
 
 def _is_scalar(shape):
     """check whether input shape is a scalar"""
-    return F.shape_mul(shape) == 1
+    return ops.shape_mul(shape) == 1
 
 
 def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
@@ -80,27 +80,27 @@ def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
 
 def _expand(x, ndim, axis=0):
     """Expand x to ndim from axis, which can be 0 or -1."""
-    shape = _add_unit_axes(F.shape(x), ndim, axis == -1)
-    return F.reshape(x, shape)
+    shape = _add_unit_axes(ops.shape(x), ndim, axis == -1)
+    return ops.reshape(x, shape)
 
 
 def _broadcast_to(x, shape_cur, shape_to, ndim_to):
     """Broadcasts x from shape_cur to shape_to."""
     size = _tile_size(shape_cur, shape_to, ndim_to)
-    return F.tile(x, size)
+    return ops.tile(x, size)
 
 
 def _broadcast_to_shape(x, shape):
     """Broadcasts x from current shape to shape"""
     ndim_to = len(shape)
     x = _expand(x, ndim_to)
-    return _broadcast_to(x, F.shape(x), shape, ndim_to)
+    return _broadcast_to(x, ops.shape(x), shape, ndim_to)
 
 
 def _get_size(x, axis=None):
     """Get the number of elements along the given axis of tensor x."""
-    if axis is None or F.tuple_len(axis) == 0:
-        axis = F.make_range(x.ndim)
+    if axis is None or ops.tuple_len(axis) == 0:
+        axis = ops.make_range(x.ndim)
     nums = 1
     for ax in axis:
         nums *= x.shape[ax]
@@ -110,7 +110,7 @@ def _get_size(x, axis=None):
 def _check_input_tensor(*tensors):
     for tensor in tensors:
         if not isinstance(tensor, Tensor):
-            _raise_type_error('expect Tensor, but got ', F.typeof(tensor))
+            _raise_type_error('expect Tensor, but got ', ops.typeof(tensor))
     return True
 
 
@@ -141,7 +141,7 @@ def _to_tensor(*args):
 
 def _get_dtype_from_scalar(*input_numbers):
     """
-    Get the final dtype from series of input numbers, compared with F.typeof, we
+    Get the final dtype from series of input numbers, compared with ops.typeof, we
     return int32/float32 for python int/float instead.
     """
     bool_flag = True
@@ -184,7 +184,7 @@ def _slice_along_axis(f, axis, slice_start, slice_end):
     slice_size = slice_end - slice_start
     index_start = _tuple_setitem(index_start, axis, slice_start)
     index_end = _tuple_setitem(index_end, axis, slice_size)
-    return F.tensor_slice(f, index_start, index_end)
+    return ops.tensor_slice(f, index_start, index_end)
 
 
 def _to_tensor_origin_dtype(*args):
@@ -203,12 +203,12 @@ def _to_tensor_origin_dtype(*args):
 
 def _callable(tensor, obj):
     """Returns True if `obj` is a function."""
-    if F.isconstant(tensor):
+    if ops.isconstant(tensor):
         return isinstance(obj, types.FunctionType)
-    return _callable_const(F.typeof(obj))
+    return _callable_const(ops.typeof(obj))
 
 
 def _isnan(x):
-    if _get_device() == 'Ascend' and not _check_is_float(F.dtype(x)):
-        return F.fill(mstype.bool_, F.shape(x), False)
-    return F.isnan(x)
+    if _get_device() == 'Ascend' and not _check_is_float(ops.dtype(x)):
+        return ops.fill(mstype.bool_, ops.shape(x), False)
+    return ops.isnan(x)
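Note: the change above mechanically renames every call site from the internal functional alias (mindspore.ops.functional as F) to the public mindspore.ops namespace. A minimal sketch of the same pattern against the public API (ops.shape and ops.reshape are documented MindSpore functions; the helper name here is illustrative, not from the package):

import mindspore as ms
from mindspore import ops

def expand_to_ndim(x, ndim):
    # Prepend unit axes until `x` has `ndim` dimensions, mirroring
    # what _expand() does via _add_unit_axes() in the diff above.
    shape = (1,) * (ndim - len(ops.shape(x))) + ops.shape(x)
    return ops.reshape(x, shape)

x = ms.Tensor([[1.0, 2.0]])        # shape (1, 2)
print(expand_to_ndim(x, 4).shape)  # (1, 1, 1, 2)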
mindspore/numpy/utils_const.py
CHANGED
@@ -20,7 +20,6 @@ from itertools import accumulate
 import operator
 
 import mindspore.context as context
-from mindspore.ops import functional as F
 from mindspore.ops.primitive import constexpr
 from mindspore.ops.primitive import _primexpr
 from mindspore.common import dtype as mstype
@@ -28,6 +27,7 @@ from mindspore.common import Tensor
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import typing
 from mindspore import _checkparam as validator
+from mindspore import ops
 
 from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
 
@@ -80,13 +80,13 @@ def _check_dtype(dtype):
 @_primexpr
 def _is_shape_empty(shp):
     """Check whether shape contains zero"""
-    if F.is_sequence_shape_unknown(shp):
+    if ops.is_sequence_shape_unknown(shp):
         return False
     if isinstance(shp, int):
         return shp == 0
     if isinstance(shp, (tuple, list)):
         return 0 in shp
-    return F.shape_mul(shp) == 0
+    return ops.shape_mul(shp) == 0
 
 
 @_primexpr
@@ -189,7 +189,7 @@ def _check_axis_valid(axes, ndim):
         raise ValueError('duplicate value in "axis"')
 
     if axes is None:
-        axes = F.make_range(ndim)
+        axes = ops.make_range(ndim)
         return axes
     if isinstance(axes, (tuple, list)):
         axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
mindspore/opencv_core452.dll
CHANGED
Binary file
mindspore/opencv_imgcodecs452.dll
CHANGED
Binary file
mindspore/opencv_imgproc452.dll
CHANGED
Binary file
mindspore/ops/__init__.py
CHANGED
@@ -33,6 +33,7 @@ from mindspore.ops import composite, operations, functional, function
 from mindspore.ops import signature
 from mindspore.ops.auto_generate import cpp_create_prim_instance_helper, gen_arg_dtype_cast, gen_arg_handler, \
     gen_extend_func, gen_ops_def, gen_ops_prim, pyboost_inner_prim
+from mindspore.ops.functional_overload import all_gather_matmul, matmul_reduce_scatter
 from mindspore.ops.composite import *
 from mindspore.ops.operations import *
 from mindspore.ops.function import *
@@ -47,7 +48,7 @@ __all__ = ["get_vm_impl_fn", "vm_impl_registry",
            "CpuRegOp", "CustomRegOp", "DataType",
            "constexpr", "reshard",
            "cpp_create_prim_instance_helper", "gen_arg_dtype_cast", "gen_arg_handler", "gen_extend_func", "gen_ops_def",
-           "gen_ops_prim", "pyboost_inner_prim"]
+           "gen_ops_prim", "pyboost_inner_prim", "all_gather_matmul", "matmul_reduce_scatter"]
 __all__.extend(__primitive__)
 __all__.extend(composite.__all__)
 __all__.extend(operations.__all__)
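Note: all_gather_matmul and matmul_reduce_scatter are fused communication-plus-matmul kernels newly re-exported at the top level. A hedged usage sketch for a launched multi-rank Ascend job; the positional group/world_size arguments are an assumption about the 2.5.0 signature, not something this diff confirms:

# Sketch only: requires a launched multi-device job (e.g. via msrun) on Ascend.
import mindspore as ms
from mindspore import ops
from mindspore.communication import init, get_group_size, GlobalComm

init()
world_size = get_group_size()
x = ops.randn(128, 256, dtype=ms.float16)   # rank-local activation shard
w = ops.randn(256, 512, dtype=ms.float16)   # replicated weight
# Fuses the AllGather of `x` across ranks with the subsequent matmul.
out, gathered = ops.all_gather_matmul(x, w, GlobalComm.WORLD_COMM_GROUP, world_size)
print(out.shape)  # expected (128 * world_size, 512) if gather-output semantics hold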
mindspore/ops/_grad_experimental/grad_comm_ops.py
CHANGED
@@ -16,7 +16,7 @@
 """Generate bprop for comm ops"""
 from __future__ import division
 from __future__ import absolute_import
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.communication import get_rank, get_group_size
@@ -37,6 +37,24 @@ from mindspore.ops._grad_experimental.grad_base import bprop_getters
 from mindspore.ops.operations import _grad_ops as G
 import mindspore as ms
 
+_squared_device_local_norm = None
+
+
+def get_squared_device_local_norm_param():
+    """
+    Get Parameter `_squared_device_local_norm`.
+    `_squared_device_local_norm` will accumulate squared local norm of each grad in bprop under GRAPH_MODE.
+    User need to reset it to zero after network propagation each step.
+    """
+    global _squared_device_local_norm
+    if _squared_device_local_norm is None:
+        if ms.get_auto_parallel_context("dump_device_local_norm"):
+            _squared_device_local_norm = Parameter(Tensor(0.0, mstype.float32), name="_squared_device_local_norm",
+                                                   requires_grad=False)
+        else:
+            raise ValueError("The parallel config 'dump_device_local_norm' is False.")
+    return _squared_device_local_norm
+
 
 @bprop_getters.register(AllReduce)
 def get_bprop_all_reduce(self):
@@ -192,7 +210,7 @@ def get_bprop_virtual_assign_kv_cache(self):
         dout_update = dout + y
         kv_equal = F.equal(seq_chunk, 0)
         update_kv = F.select(kv_equal, F.broadcast_to(cast(out_tensor, dtype(y)), F.shape(y)), dout_update)
-        return F.depend((dout_update, cast(out_tensor, dtype(y)),
+        return F.depend((cast(dout_update, dtype(dout)), cast(out_tensor, dtype(y)),
                          cast(out_tensor, dtype(seq_chunk))), assign(y, update_kv))
 
     return bprop
@@ -218,6 +236,7 @@ def get_bprop_mirror_micro_step_operator(self):
     allgather for sparse feature.
     """
     group = self.group
+    global_rank = get_rank()
     dev_num = self.dev_num
     mean_flag = self.mean_flag
     param_name = " "
@@ -244,13 +263,29 @@ def get_bprop_mirror_micro_step_operator(self):
     out_tensor = Tensor(1.0, mstype.float16)
     opt_shard = _get_enable_parallel_optimizer()
     ln_print = P.Print()
+    tensor_dump = P.TensorDump()
     reduce_sum = P.ReduceSum(keep_dims=False)
     square = P.Square()
+    sqrt = P.Sqrt()
     dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
+    dump_local_norm_path = ms.get_auto_parallel_context("dump_local_norm_path")
+    dump_device_local_norm = ms.get_auto_parallel_context("dump_device_local_norm")
+    if dump_device_local_norm:
+        # init _squared _squared_device_local_norm
+        squared_device_local_norm = get_squared_device_local_norm_param()
 
     def bprop(x, z, out, dout):
-        if dump_local_norm:
-            z = F.depend(z, ln_print("dump local norm: ", param_name, reduce_sum(square(z))))
+        if dump_local_norm or dump_device_local_norm:
+            squared_norm = reduce_sum(square((z)))
+            if dump_local_norm:
+                if dump_local_norm_path:
+                    z = F.depend(z, tensor_dump(dump_local_norm_path + "/rank_" + str(global_rank) +
+                                                "/local_norm__" + param_name, sqrt(squared_norm)))
+                else:
+                    z = F.depend(z, ln_print("dump local norm: ", param_name, sqrt(squared_norm)))
+            if dump_device_local_norm:
+                z = F.depend(z, F.assign_add(squared_device_local_norm,
+                                             cast(squared_norm, squared_device_local_norm.dtype)))
         real_grad = z
         assign_out = dout
         if issubclass_(F.typeof(dout), mstype.tensor_type):
@@ -293,8 +328,38 @@ def get_bprop_all_gather(self):
     if self.rank_size == 0:
         raise ValueError(f"The 'rank_size' can not be zero, but got {self.rank_size}.")
     scale = 1.0 / self.rank_size
+    param_name = ""
+    if 'mirror_user_id' in self.get_attr_dict():
+        param_name = self.get_attr_dict()['mirror_user_id']
+    # monitor local norm
+    dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
+    dump_local_norm_path = ms.get_auto_parallel_context("dump_local_norm_path")
+    dump_device_local_norm = ms.get_auto_parallel_context("dump_device_local_norm")
+    if param_name and (dump_local_norm or dump_device_local_norm):
+        global_rank = get_rank()
+        cast = P.Cast()
+        ln_print = P.Print()
+        tensor_dump = P.TensorDump()
+        reduce_sum = P.ReduceSum(keep_dims=False)
+        square = P.Square()
+        sqrt = P.Sqrt()
+        if dump_device_local_norm:
+            # init _squared _squared_device_local_norm
+            squared_device_local_norm = get_squared_device_local_norm_param()
 
     def bprop(x, out, dout):
+        if param_name and (dump_local_norm or dump_device_local_norm):
+            squared_norm = reduce_sum(square((dout)))
+            if dump_local_norm:
+                if dump_local_norm_path:
+                    dout = F.depend(dout, tensor_dump(dump_local_norm_path + "/rank_" + str(global_rank) +
+                                                      "/local_norm__" + param_name, sqrt(squared_norm)))
+                else:
+                    dout = F.depend(dout, ln_print("dump local norm: ", param_name, sqrt(squared_norm)))
+            if dump_device_local_norm:
+                dout = F.depend(dout, F.assign_add(squared_device_local_norm,
+                                                   cast(squared_norm, squared_device_local_norm.dtype)))
+
         dx = reduce_scatter(dout)
         if mean_flag:
             dx = F.tensor_mul(dx, scale)
@@ -365,14 +430,22 @@ def get_bprop_micro_step_all_gather(self):
     if self.instance_name:
         instance_name = "grad_" + self.instance_name
         reduce_scatter.set_prim_instance_name(instance_name)
+    global_rank = get_rank()
     cast = P.Cast()
     dtype = P.DType()
     out_tensor = Tensor(1.0, mstype.float16)
     with_mirror_operator = self.get_attr_dict()["with_mirror_operator"]
     ln_print = P.Print()
+    tensor_dump = P.TensorDump()
     reduce_sum = P.ReduceSum(keep_dims=False)
     square = P.Square()
+    sqrt = P.Sqrt()
     dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
+    dump_local_norm_path = ms.get_auto_parallel_context("dump_local_norm_path")
+    dump_device_local_norm = ms.get_auto_parallel_context("dump_device_local_norm")
+    if dump_device_local_norm:
+        # init _squared _squared_device_local_norm
+        squared_device_local_norm = get_squared_device_local_norm_param()
 
     def bprop(x, z, out, dout):
         if with_mirror_operator:
@@ -383,8 +456,17 @@ def get_bprop_micro_step_all_gather(self):
             real_grad = F.tensor_mul(real_grad, scale)
             return (real_grad, cast(out_tensor, dtype(z)))
         z = F.depend(z, dout)
-        if dump_local_norm:
-            z = F.depend(z, ln_print("dump local norm: ", param_name, reduce_sum(square(z))))
+        if dump_local_norm or dump_device_local_norm:
+            squared_norm = reduce_sum(square((z)))
+            if dump_local_norm:
+                if dump_local_norm_path:
+                    z = F.depend(z, tensor_dump(dump_local_norm_path + "/rank_" + str(global_rank) +
+                                                "/local_norm__" + param_name, sqrt(squared_norm)))
+                else:
+                    z = F.depend(z, ln_print("dump local norm: ", param_name, sqrt(squared_norm)))
+            if dump_device_local_norm:
+                z = F.depend(z, F.assign_add(squared_device_local_norm,
+                                             cast(squared_norm, squared_device_local_norm.dtype)))
         if not do_mirror:
             return (z, cast(out_tensor, dtype(z)))
         real_grad = reduce_scatter(z)
@@ -586,15 +668,23 @@ def get_bprop_mirror_operator(self):
 
     dev_num_r = 1.0
     dump_local_norm = ms.get_auto_parallel_context("dump_local_norm")
+    dump_local_norm_path = ms.get_auto_parallel_context("dump_local_norm_path")
+    dump_device_local_norm = ms.get_auto_parallel_context("dump_device_local_norm")
+    if dump_device_local_norm:
+        # init _squared _squared_device_local_norm
+        squared_device_local_norm = get_squared_device_local_norm_param()
     if dev_num > 1:
+        global_rank = get_rank()
        dev_num_r = 1.0 / dev_num
        all_reduce = AllReduce(group=group)
        all_gather = AllGather(group=group)
        mul = P.Mul()
        cast = P.Cast()
        ln_print = P.Print()
+        tensor_dump = P.TensorDump()
        reduce_sum = P.ReduceSum(keep_dims=False)
        square = P.Square()
+        sqrt = P.Sqrt()
 
     fusion = self.get_attr_dict()["fusion"]
     all_reduce.add_prim_attr("fusion", fusion)
@@ -608,8 +698,17 @@ def get_bprop_mirror_operator(self):
         all_reduce.set_prim_instance_name(instance_name)
 
     def bprop(x, out, dout):
-        if dump_local_norm:
-            dout = F.depend(dout, ln_print("dump local norm: ", param_name, reduce_sum(square(dout))))
+        if dump_local_norm or dump_device_local_norm:
+            squared_norm = reduce_sum(square((dout)))
+            if dump_local_norm:
+                if dump_local_norm_path:
+                    dout = F.depend(dout, tensor_dump(dump_local_norm_path + "/rank_" + str(global_rank) +
+                                                      "/local_norm__" + param_name, sqrt(squared_norm)))
+                else:
+                    dout = F.depend(dout, ln_print("dump local norm: ", param_name, sqrt(squared_norm)))
+            if dump_device_local_norm:
+                dout = F.depend(dout, F.assign_add(squared_device_local_norm,
+                                                   cast(squared_norm, squared_device_local_norm.dtype)))
 
         if dev_num == 1:
             return (dout,)
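Note: per the docstring of get_squared_device_local_norm_param above, the accumulated squared norm is meant to be read and reset to zero by the training loop once per step. A minimal sketch of that consumption side; assuming dump_device_local_norm can be enabled through set_auto_parallel_context (the setter key mirrors the getter key used in the diff and is an assumption):

import mindspore as ms
from mindspore import ops
from mindspore.ops._grad_experimental.grad_comm_ops import get_squared_device_local_norm_param

ms.set_auto_parallel_context(dump_device_local_norm=True)  # assumed setter key
squared_norm = get_squared_device_local_norm_param()

# ... run one training step here; each registered bprop accumulates
# its grad's squared norm into the Parameter via F.assign_add ...

print("device local norm:", ops.sqrt(squared_norm.value()))
squared_norm.set_data(ms.Tensor(0.0, ms.float32))  # reset before the next step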
mindspore/ops/_grad_experimental/grad_debug_ops.py
CHANGED
@@ -15,6 +15,7 @@
 
 """Generate bprop for debug ops"""
 
+import mindspore.ops.functional as F
 from mindspore.ops import operations as P
 from mindspore.ops._grad_experimental.grad_base import bprop_getters
 
@@ -27,5 +28,9 @@ def get_bprop_insert_gradient_of(self):
     f = self.f
 
     def bprop(x, out, dout):
-        return (f(dout),)
+        fdout = f(dout)
+        if fdout is None:
+            dout = F.depend(dout, fdout)
+            return (dout,)
+        return (fdout,)
     return bprop
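Note: the fix above makes InsertGradientOf keep the original dout flowing when the user hook returns None (a pure side-effect hook), instead of propagating None. A small illustration with the documented public API:

import mindspore as ms
from mindspore import ops

def clip_hook(grad):
    # Returning a tensor replaces the gradient; a hook that returns
    # None (side effects only) now leaves `dout` unchanged.
    return ops.clip_by_value(grad, -0.1, 0.1)

clip = ops.InsertGradientOf(clip_hook)

def net(x):
    return clip(x) ** 2

x = ms.Tensor(3.0, ms.float32)
print(ms.grad(net)(x))  # d(x**2)/dx = 6.0, clipped by the hook to 0.1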
mindspore/ops/_grad_experimental/grad_inner_ops.py
CHANGED
@@ -23,6 +23,15 @@ from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops._grad_experimental.grad_base import bprop_getters
 
 
+@bprop_getters.register("raise")
+def get_bprop_raise(self):
+    """Grad definition for `raise` operation."""
+    def bprop(x, y, z, out, dout):
+        return x, y, z
+
+    return bprop
+
+
 @bprop_getters.register(inner.ParallelResizeBilinear)
 def get_bprop_parallel_resize_bilinear(self):
     """Grad definition for `ParallelResizeBilinear` operation."""
mindspore/ops/_grad_experimental/grad_math_ops.py
CHANGED
@@ -657,7 +657,8 @@ def get_bprop_fft_with_size(self):
         dx = rfft_fn(dout)
         dx = reverse_branch(dx, onesided, dout_shape, offset_shape,
                             output_type, dout, norm, inverse, signal_ndim, offset_size)
-        return dx, zeros_like(signal_ndim), zeros_like(inverse), zeros_like(real), zeros_like(norm_enum), zeros_like(onesided)
+        return dx, zeros_like(signal_ndim), zeros_like(inverse), zeros_like(real), zeros_like(norm_enum), \
+            zeros_like(onesided), zeros_like(signal_sizes)
 
     return bprop
 
mindspore/ops/_op_impl/cpu/__init__.py
CHANGED
@@ -71,6 +71,7 @@ from .pyexecute import _pyexecute_cpu
 from .pyfunc import _pyfunc_cpu
 from .buffer_append import _buffer_append_cpu
 from .buffer_get import _buffer_get_cpu
+from .raise_op import _raise_cpu
 from .buffer_sample import _buffer_sample_cpu
 from .priority_replay_buffer import _prb_push_op_cpu
 from .priority_replay_buffer import _prb_sample_op_cpu
mindspore/ops/_op_impl/cpu/raise_op.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""storeattrgrad op"""
+from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
+
+raise_op_info = CpuRegOp("raise") \
+    .input(0, "x", "dynamic") \
+    .output(0, "y", "dynamic") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(raise_op_info)
+def _raise_cpu():
+    """_getattrgrad_cpu cpu register"""
+    return
mindspore/ops/_vmap/vmap_array_ops.py
CHANGED
@@ -15,6 +15,7 @@
 
 """array_ops vmap impl."""
 from __future__ import absolute_import
+from enum import Enum
 
 import mindspore
 import mindspore.numpy as mnp
@@ -1488,16 +1489,19 @@ def get_meshgrid_vmap_rule(prim, axis_size):
     """VmapRule for `P.Meshgrid` operation."""
     if isinstance(prim, str):
         prim = Primitive(prim)
-    indexing = prim.indexing
 
-    def vmap_rule(*inputs_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, inputs_bdim)
+    class Indexing(Enum):
+        ij = 0
+        xy = 1
+
+    def vmap_rule(inputs_bdim, indexing_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, inputs_bdim, indexing_bdim)
         if is_all_none:
             return result
 
         if not isinstance(inputs_bdim, (tuple)):
             _raise_value_error("The inputs of P.Meshgrid is not tuple.")
-        args = inputs_bdim
+        args = inputs_bdim
         if len(args) <= 1:
             _raise_value_error(
                 "The input number of P.Meshgrid must be greater than 1.")
@@ -1518,7 +1522,9 @@ def get_meshgrid_vmap_rule(prim, axis_size):
         output_shape.insert(0, axis_size)
         ones_shape.insert(0, axis_size)
 
-        if indexing == "xy":
+        indexing, _ = indexing_bdim
+
+        if indexing == Indexing.xy.value:
             output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
         shape = tuple(output_shape)
 
@@ -1531,7 +1537,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
         for each_arg in args:
             x, bdim = each_arg
             x = _bdim_at_front(x, bdim, axis_size)
-            shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
+            shape_index = (1 - index) if (index <= 1 and indexing == Indexing.xy.value) else index
             ones_shape[shape_index + 1] = output_shape[shape_index + 1]
             x = P.Reshape()(x, tuple(ones_shape))
             output = P.Mul()(x, ones_tensor)
@@ -1889,10 +1895,6 @@ def get_slice_vmap_rule(prim, axis_size):
 @vmap_rules_getters.register(P.Squeeze)
 def get_squeeze_vmap_rule(prim, axis_size):
     """VmapRule for `Squeeze`."""
-    if hasattr(prim, 'axis'):
-        prim_axis = prim.axis
-    else:
-        prim_axis = None
 
     @_primexpr
     def move_axis(axes):
@@ -1911,27 +1913,26 @@ def get_squeeze_vmap_rule(prim, axis_size):
             new_axis += (i,)
         return new_axis
 
-    def vmap_rule(x_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+    def vmap_rule(x_bdim, axis_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
         if is_all_none:
             return result
 
         x, x_dim = x_bdim
+        axis, _ = axis_bdim
         x = _bdim_at_front(x, x_dim, axis_size)
 
-        if prim_axis is None:
+        if axis is None:
             if axis_size == 1:
                 new_axis = generate_all_axis_except_first(F.rank(x))
-                batch_squeeze = P.Squeeze(new_axis)
-                out = batch_squeeze(x)
+                out = prim(x, new_axis)
                 return out, 0
 
-            out = prim(x)
+            out = prim(x, axis)
             return out, 0
 
-        new_axis = move_axis(prim_axis)
-        batch_squeeze = P.Squeeze(new_axis)
-        out = batch_squeeze(x)
+        new_axis = move_axis(axis)
+        out = prim(x, new_axis)
         return out, 0
 
     return vmap_rule
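Note: the Squeeze rule now receives axis as a runtime input (axis_bdim) rather than reading a primitive attribute, matching the op's migration to argument-passing form. The behavior the rule implements can be checked through the public API (ops.vmap and ops.squeeze are documented functions; the snippet is illustrative):

import numpy as np
import mindspore as ms
from mindspore import ops

x = ms.Tensor(np.zeros((4, 1, 3), np.float32))  # dim 0 is the mapped batch
# Each per-example input seen inside vmap has shape (1, 3), so
# squeezing axis 0 there removes the unit axis, not the batch axis.
out = ops.vmap(lambda t: ops.squeeze(t, 0), in_axes=0)(x)
print(out.shape)  # (4, 3)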
mindspore/ops/_vmap/vmap_base.py
CHANGED
@@ -512,8 +512,6 @@ _ops_vmap_clone_prim_dict = {
     "ApplyAdagradV2": P.ApplyAdagradV2,
     "UniformCandidateSampler": UniformCandidateSampler,
     "UniqueWithPad": P.UniqueWithPad,
-    "CdistGrad": G.CdistGrad,
-    "Cdist": P.Cdist,
     "STFT": math_ops.STFT,
     "Conv2D": P.Conv2D,
     "Conv3D": P.Conv3D,
mindspore/ops/_vmap/vmap_grad_nn_ops.py
CHANGED
@@ -26,6 +26,7 @@ from mindspore.ops.function import _VmapGeneralRule
 from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
     _bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
 from mindspore.ops.auto_generate.gen_arg_handler import Format, Reduction
+from mindspore.ops import auto_generate as gen
 
 
 @vmap_rules_getters.register(G.NLLLossGrad)
@@ -225,33 +226,35 @@ def get_max_pool3d_grad_with_argmax_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
-@vmap_rules_getters.register(G.CdistGrad)
+@vmap_rules_getters.register(gen.CdistGrad)
 def get_cdist_grad_vmap_rule(prim, axis_size):
     """VmapRule for `cdist grad` operation."""
-    if hasattr(prim, 'batch_rank'):
-        batch_rank = prim.batch_rank + 1
+    if prim.has_label("batch_rank"):
+        batch_rank = prim.get_label("batch_rank") + 1
     else:
         batch_rank = 1
 
-    prim = _vmap_clone_prim(prim)
-    prim.add_prim_attr('batch_rank', batch_rank)
+    prim = prim.clone()
+    prim.set_label('batch_rank', batch_rank)
 
-    def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim):
-        is_all_none, result = vmap_general_preprocess(
-            prim, grad_bdim, x_bdim, y_bdim, cdist_bdim)
+    def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim):
+        is_all_none, result = vmap_general_preprocess(
+            prim, grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim
+        )
         if is_all_none:
             return result
         grad, grad_dim = grad_bdim
        x, x_dim = x_bdim
        y, y_dim = y_bdim
        cdist, cdist_dim = cdist_bdim
+        p, _ = p_bdim
 
         grad = _bdim_at_front(grad, grad_dim, axis_size)
         x = _bdim_at_front(x, x_dim, axis_size)
         y = _bdim_at_front(y, y_dim, axis_size)
         cdist = _bdim_at_front(cdist, cdist_dim, axis_size)
 
-        out = prim(grad, x, y, cdist)
+        out = prim(grad, x, y, cdist, p)
         return out, 0
 
     return vmap_rule
@@ -673,10 +676,11 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
     else:
         _raise_value_error("The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
 
-    def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim,
-                  align_corners_bdim):
+    def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
+                  output_mask_bdim):
         is_all_none, result = vmap_general_preprocess(
-            prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim)
+            prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
+            output_mask_bdim)
         if is_all_none:
             return result
 
@@ -686,6 +690,7 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
         interpolation_mode, _ = interpolation_mode_bdim
         padding_mode, _ = padding_mode_bdim
         align_corners, _ = align_corners_bdim
+        output_mask, _ = output_mask_bdim
 
         grad = _bdim_at_front(grad, grad_dim, axis_size)
         grad_shape = F.shape(grad)
@@ -699,7 +704,8 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
         grid_shape = F.shape(grid)
         grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
 
-        dx, dgrid = prim(grad, input_x, grid, interpolation_mode, padding_mode, align_corners)
+        dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
+                         padding_mode, align_corners, output_mask)
         dx_shape = F.shape(dx)
         dx_return_shape = input_x_shape[:non_batch_dim_index] + dx_shape[non_batch_dim_index:]
         dx = F.reshape(dx, dx_return_shape)