mindspore 2.6.0-cp311-cp311-win_amd64.whl → 2.7.0rc1-cp311-cp311-win_amd64.whl
This diff reflects the changes between two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +40 -9
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
- mindspore/_extends/parse/parser.py +36 -61
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +32 -13
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +27 -2
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +76 -15
- mindspore/common/api.py +193 -112
- mindspore/common/dtype.py +21 -11
- mindspore/common/dump.py +10 -15
- mindspore/common/generator.py +2 -3
- mindspore/common/hook_handle.py +11 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/parameter.py +26 -12
- mindspore/common/recompute.py +3 -3
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +48 -83
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +38 -23
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/datasets.py +20 -7
- mindspore/dataset/engine/datasets_user_defined.py +32 -2
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +7 -3
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -5
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +1 -0
- mindspore/include/api/cell.h +37 -1
- mindspore/include/api/delegate.h +10 -0
- mindspore/include/api/model.h +3 -0
- mindspore/include/api/types.h +2 -2
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +60 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +4 -44
- mindspore/mint/distributed/__init__.py +1 -0
- mindspore/mint/distributed/distributed.py +208 -5
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +164 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +122 -98
- mindspore/mint/nn/layer/normalization.py +8 -22
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +325 -499
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +32 -34
- mindspore/nn/layer/basic.py +67 -64
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +86 -85
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +37 -39
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +34 -37
- mindspore/nn/wrap/grad_reducer.py +37 -37
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +5 -5
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_vmap/vmap_array_ops.py +6 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +17 -8
- mindspore/ops/auto_generate/gen_extend_func.py +1 -51
- mindspore/ops/auto_generate/gen_ops_def.py +463 -257
- mindspore/ops/auto_generate/gen_ops_prim.py +1127 -885
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +8 -4
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +3 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +7 -94
- mindspore/ops/function/debug_func.py +4 -3
- mindspore/ops/function/grad/grad_func.py +1 -1
- mindspore/ops/function/math_func.py +21 -367
- mindspore/ops/function/nn_func.py +26 -41
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +31 -4
- mindspore/ops/functional.py +0 -2
- mindspore/ops/functional_overload.py +463 -6
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_custom_ops_utils.py +675 -8
- mindspore/ops/operations/_inner_ops.py +3 -6
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/comm_ops.py +185 -26
- mindspore/ops/operations/custom_ops.py +235 -172
- mindspore/ops/operations/debug_ops.py +55 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +15 -16
- mindspore/ops/operations/math_ops.py +3 -4
- mindspore/ops/operations/nn_ops.py +5 -6
- mindspore/ops/primitive.py +6 -10
- mindspore/ops/tensor_method.py +36 -4
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +7 -2
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +4 -2
- mindspore/parallel/_cell_wrapper.py +106 -40
- mindspore/parallel/_parallel_serialization.py +1 -1
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +13 -8
- mindspore/parallel/auto_parallel.py +12 -5
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +3 -1
- mindspore/parallel/cluster/process_entity/_api.py +84 -48
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +43 -4
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +1 -1
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
- mindspore/parallel/shard.py +2 -2
- mindspore/parallel/transform_safetensors.py +462 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/profiler_context.py +25 -27
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_op_analyse.py +235 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +305 -314
- mindspore/profiler/envprofiler.py +12 -7
- mindspore/profiler/experimental_config.py +96 -6
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/npu_profiler.py +29 -19
- mindspore/profiler/profiler.py +35 -19
- mindspore/profiler/profiler_action_controller.py +64 -76
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +5 -5
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +60 -45
- mindspore/runtime/memory.py +21 -30
- mindspore/runtime/thread_bind_core.py +298 -164
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +6 -2
- mindspore/train/amp.py +43 -20
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_checkpoint.py +3 -6
- mindspore/train/callback/_flops_collector.py +1 -1
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +71 -13
- mindspore/train/data_sink.py +11 -2
- mindspore/train/dataset_helper.py +9 -0
- mindspore/train/model.py +51 -33
- mindspore/train/serialization.py +133 -111
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +162 -78
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +6 -9
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +352 -390
- mindspore/_deprecated/jit.py +0 -198
- mindspore/experimental/es/__init__.py +0 -22
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/mint/nn/layer/_functions.py
@@ -1,323 +1,193 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""functions for mint"""
 import mindspore
-from mindspore import
-from mindspore import
-import mindspore.communication
-import mindspore.communication.comm_func
+from mindspore import ops, mint
+from mindspore import _checkparam as validator
 from mindspore.nn.cell import Cell
+from mindspore.communication.comm_func import all_gather_into_tensor
+from mindspore.communication.comm_func import all_reduce
+from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group
 from mindspore.ops.auto_generate.gen_ops_prim import BatchNormReduceGrad
 from mindspore.ops.auto_generate.gen_ops_prim import BatchNormElemtGrad
-from mindspore.
-from mindspore.ops import ReduceOp
-from mindspore._c_expression import TensorPy as Tensor_
-from mindspore.communication._comm_helper import _get_size_helper, HCCL_WORLD_COMM_GROUP
-from mindspore.ops._primitive_cache import _get_cache_prim
-from mindspore.communication.comm_func import all_gather_into_tensor as all_gather_into_tensor_dy
-from mindspore.ops import operations as P
-from mindspore import ops, mint
-
-
-DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+from mindspore.ops.primitive import Primitive, prim_arg_register, PrimitiveWithInfer, prim_attr_register
+from mindspore.ops.operations.comm_ops import ReduceOp, check_hcom_group_valid, check_collective_target_dtype
 
 batch_norm_reduce_grad = BatchNormReduceGrad()
 batch_norm_elemt_grad = BatchNormElemtGrad()
-shape =
-
-
-
-
-
-
-
+shape = ops.Shape()
+
+
+class AllGather(PrimitiveWithInfer):
+    @prim_arg_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+        super(AllGather, self).__init__(self.__class__.__name__)
+        self.group = _get_group(group)
+        validator.check_value_type('group', self.group, (str,), self.name)
+        self.rank = get_rank(self.group)
+        self.rank_size = get_group_size(self.group)
+        validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', self.group)
+        self.add_prim_attr('fusion', 0)
+        self.add_prim_attr('mean_flag', False)
+        self.add_prim_attr('no_eliminate', True)
+
+    def __call__(self, combined):
+        output, _ = all_gather_into_tensor(combined, group=self.group)
         return output
 
-
+    def infer_shape(self, x_shape):
+        validator.check_positive_int(len(x_shape), "x shape", self.name)
+        if x_shape[0] > 0:
+            x_shape[0] = x_shape[0] * self.rank_size
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        check_collective_target_dtype('x', x_dtype, self.name)
+        return x_dtype
+
+
+class AllReduce(Primitive):
+    @prim_attr_register
+    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+        """Initialize AllReduce."""
+        super().__init__(name="AllReduce")
+        self.group = _get_group(group)
+        if not isinstance(op, type(ReduceOp.SUM)):
+            raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
+        if not isinstance(self.group, str):
+            raise TypeError(f"For '{self.name}', the 'group' must be str, "
+                            f"but got {type(self.group).__name__}.")
+        check_hcom_group_valid(self.group, prim_name=self.name)
+        self.op = op
+        self.add_prim_attr('group', self.group)
+        self.add_prim_attr('fusion', 0)
+        self.add_prim_attr('index', 0)
+        self.add_prim_attr('no_eliminate', True)
+
+    def __call__(self, combined):
+        output, _ = all_reduce(combined, group=self.group)
         return output
-        return output
-
-
-def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
-    if not isinstance(group, str):
-        raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
-                        "but got 'group' type : {}.".format(type(group)))
-    return _get_size_helper(group=_get_group(group))
-
-
-def _contiguous(tensor):
-    if not tensor.is_contiguous() or tensor.storage_offset() != 0:
-        tensor = tensor.contiguous()
-    return tensor
 
 
-
-
-
-
-
-
-def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
-    if not isinstance(tensor, (Tensor, Tensor_)):
-        raise TypeError(
-            "For all_gather_into_tensor, the input tensor must be tensor")
-    group = _get_group(group)
-    tensor = _contiguous(tensor)
-    all_gather_op = _get_cache_prim(P.AllGather)(group=group)
-    output = all_gather_op(tensor)
-    return _deal_comm_outputs(output, async_op)
-
-
-def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
-    if not isinstance(tensor, (Tensor, Tensor_)):
-        raise TypeError("For all_reduce, the input tensor must be tensor")
-    if not isinstance(op, str):
-        raise TypeError("For all_reduce, the input op type must be str")
-    if op not in ('sum', 'prod', 'min', 'max'):
-        raise TypeError(
-            "For all_reduce, the input op value must be one of sum, prod, min, max")
-    group = _get_group(group)
-    tensor = _contiguous(tensor)
-    all_reduce_op = _get_cache_prim(P.AllReduce)(op=op, group=group)
-    output = all_reduce_op(tensor)
-    return _deal_comm_outputs(output, async_op)
-
-
-def bprop_pynative(input_x, weight, bias, running_mean, running_var, eps, momentum,
-                   process_group, world_size, output, doutput):
-    _, mean_param, invstd_param, count_all_param = output
-    dout, _, _, _ = doutput
-
-    # KBK mode is not supported
-    if not dout.is_contiguous():
-        dout = dout.contiguous()
-
-    grad_input = grad_weight = grad_bias = None
-
-    inputG = True
-    weightG = True
-    biasG = True
-
-    # calculate local stats as well as grad_weight / grad_bias
-    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
-        dout,
-        input_x,
-        mean_param,
-        invstd_param,
-        weight,
-        inputG,
-        weightG,
-        biasG
-    )
-
-    if inputG:
-        # synchronizing stats used to calculate input gradient.
-        sum_dy_shape = shape(sum_dy)
-        num_channels = sum_dy_shape[0]
-        combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
-
-        new_combined, _ = mindspore.communication.comm_func.all_reduce(
-            combined, group=process_group)
-
-        sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
-
-        # backward pass for gradient calculation
-        grad_input = batch_norm_elemt_grad(
-            dout,
-            input_x,
-            mean_param,
-            invstd_param,
-            weight,
-            sum_dy,
-            sum_dy_xmu,
-            count_all_param
-        )
-
-    # synchronizing of grad_weight / grad_bias is not needed as distributed
-    # training would handle all reduce.
-    if weight is None or not weightG:
-        grad_weight = None
-
-    if weight is None or not biasG:
-        grad_bias = None
-
-    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+class SyncBatchNormInner(Cell):
+    def __init__(self, self_num_features, self_world_size):
+        super(SyncBatchNormInner, self).__init__()
+        self.num_features = self_num_features
+        self.world_size = self_world_size
 
+    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        if self.world_size != world_size:
+            raise ValueError('World Size Error')
+        input = input.contiguous()
+        if weight is not None:
+            weight = weight.contiguous()
+
+        input_shape = shape(input)
+        input_numel = ops.numel(input)
+        size = int(input_numel // input_shape[1])
+        if size == 1 and world_size < 2:
+            raise ValueError(
+                'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+        # calculate mean/invstd for input.
+        mean, invstd = mint.batch_norm_stats(input, eps)
+        count = mint.full((1,), input_numel // input_shape[1], dtype=mean.dtype)
+
+        num_channels = input_shape[1]
+        if self.num_features != num_channels:
+            raise ValueError('Features Error')
+        # C, C, 1 -> (2C + 1)
+        combined = mint.cat([mean, invstd, count], dim=0)
+        # Use allgather instead of allreduce because count could be different across
+        # ranks, simple all reduce op can not give correct results.
+        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+        # all gathered mean, invstd and count.
+        # world_size * (2C + 1)
+        all_gather_op = AllGather(process_group)
+        combined = all_gather_op(combined)
+        combined = ops.reshape(combined, [world_size, -1])
+        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+        mean_val_all, invstd_val_all, count_val_all = mint.split(
+            combined, num_channels, dim=1)
+        # calculate global mean & invstd
+        mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
+                                                                running_var, momentum, eps, count_val_all.view(-1))
+
+        # apply element-wise normalization
+        out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+        return (out, mean, invstd, count_val_all.view(-1))
 
-def
+    def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
               process_group, world_size, output, doutput):
-
-
+        _, mean_param, invstd_param, count_all_param = output
+        dout, _, _, _ = doutput
 
-
-
-    grad_input = grad_weight = grad_bias = None
-
-    inputG = True
-    weightG = True
-    biasG = True
-
-    # calculate local stats as well as grad_weight / grad_bias
-    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
+        # KBK mode is not supported
+        dout = dout.contiguous()
 
-
+        grad_input = grad_weight = grad_bias = None
 
-
+        inputG = True
+        weightG = True
+        biasG = True
 
-    #
-
+        # calculate local stats as well as grad_weight / grad_bias
+        sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
             dout,
             input_x,
             mean_param,
             invstd_param,
             weight,
-
-
-
+            inputG,
+            weightG,
+            biasG
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    num_channels = input_shape[1]
-    if self_num_features != num_channels:
-        raise ValueError('Features Error')
-    # C, C, 1 -> (2C + 1)
-    combined = mint.cat([mean, invstd, count], dim=0)
-    # Use allgather instead of allreduce because count could be different across
-    # ranks, simple all reduce op can not give correct results.
-    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
-    # all gathered mean, invstd and count.
-    # world_size * (2C + 1)
-    combined, _ = all_gather_into_tensor_dy(combined, process_group)
-    combined = ops.reshape(combined, [world_size, -1])
-    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
-    mean_val_all, invstd_val_all, count_val_all = mint.split(
-        combined, num_channels, dim=1)
-    # calculate global mean & invstd
-    mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
-                                                            running_var, momentum, eps, count_val_all.view(-1))
-
-    # apply element-wise normalization
-    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
-    return (out, mean, invstd, count_val_all.view(-1))
-
-
-def construct_kbk(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
-                  world_size, self_num_features, self_world_size):
-    if self_world_size != world_size:
-        raise ValueError('World Size Error')
-    input = input.contiguous()
-    if weight is not None:
-        weight = weight.contiguous()
-
-    input_shape = shape(input)
-    input_numel = ops.numel(input)
-    size = int(input_numel // input_shape[1])
-    if size == 1 and world_size < 2:
-        raise ValueError(
-            'Expected more than 1 value per channel when training, got input size {}'.format(size))
-
-    # calculate mean/invstd for input.
-    mean, invstd = mint.batch_norm_stats(input, eps)
-    count = mint.full((1,), input_numel //
-                      input_shape[1], dtype=mean.dtype)
-
-    num_channels = input_shape[1]
-    if self_num_features != num_channels:
-        raise ValueError('Features Error')
-    # C, C, 1 -> (2C + 1)
-    combined = mint.cat([mean, invstd, count], dim=0)
-    # Use allgather instead of allreduce because count could be different across
-    # ranks, simple all reduce op can not give correct results.
-    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
-    # all gathered mean, invstd and count.
-    # world_size * (2C + 1)
-    combined = all_gather_into_tensor(combined, process_group)
-    combined = ops.reshape(combined, [world_size, -1])
-    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
-    mean_all, invstd_all, count_all = mint.split(
-        combined, num_channels, dim=1)
-    # calculate global mean & invstd
-    mean, invstd = mint.batch_norm_gather_stats_with_counts(
-        input,
-        mean_all,
-        invstd_all,
-        running_mean,
-        running_var,
-        momentum,
-        eps,
-        count_all.view(-1)
-    )
-
-    # apply element-wise normalization
-    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
-    return (out, mean, invstd, count_all.view(-1))
-
-
-class SyncBatchNormInner(Cell):
-    def __init__(self, self_num_features, self_world_size):
-        super(SyncBatchNormInner, self).__init__()
-        self.num_features = self_num_features
-        self.world_size = self_world_size
-        self.mode = context.get_context("mode")
-        if self.mode == 1:
-            self.fn_bprop = bprop_pynative
-            self.fn_construct = construct_pynative
-        else:
-            self.fn_bprop = bprop_kbk
-            self.fn_construct = construct_kbk
-
-    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
-        return self.fn_construct(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
-                                 world_size, self.num_features, self.world_size)
-
-    def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
-              process_group, world_size, output, doutput):
-        return self.fn_bprop(input_x, weight, bias, running_mean, running_var, eps, momentum,
-                             process_group, world_size, output, doutput)
+        if inputG:
+            # synchronizing stats used to calculate input gradient.
+            sum_dy_shape = shape(sum_dy)
+            num_channels = sum_dy_shape[0]
+            combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+            all_reduce_op = AllReduce(group=process_group)
+            new_combined = all_reduce_op(combined)
+
+            sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+            # backward pass for gradient calculation
+            grad_input = batch_norm_elemt_grad(
+                dout,
+                input_x,
+                mean_param,
+                invstd_param,
+                weight,
+                sum_dy,
+                sum_dy_xmu,
+                count_all_param
+            )
+
+        # synchronizing of grad_weight / grad_bias is not needed as distributed
+        # training would handle all reduce.
+        if weight is None or not weightG:
+            grad_weight = None
+
+        if weight is None or not biasG:
+            grad_bias = None
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
 
 
 class _SyncBatchNorm(Cell):
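Reviewer note: the net effect of this hunk is that the old per-mode helpers (bprop_pynative, construct_kbk, construct_pynative, and the mode dispatch in SyncBatchNormInner.__init__) are folded into a single Cell, while the raw communication helpers are wrapped in local AllGather / AllReduce primitive classes whose __call__ delegates to mindspore.communication.comm_func and whose infer_shape / infer_dtype keep graph-mode compilation working. A minimal driver sketch follows; the module import path is the private file from the list above, and the shapes, eps/momentum values, and use of the default world group are illustrative assumptions, not taken from the diff:

# Hypothetical driver for the rewritten SyncBatchNormInner; run it under a
# distributed launcher (e.g. msrun) so that init() can set up the world group.
import numpy as np
from mindspore import Tensor
from mindspore.communication import init, get_group_size, GlobalComm
from mindspore.mint.nn.layer._functions import SyncBatchNormInner  # private API

init()
world_size = get_group_size()
num_features = 8  # illustrative channel count

bn_inner = SyncBatchNormInner(num_features, world_size)

x = Tensor(np.random.randn(4, num_features, 16, 16).astype(np.float32))
weight = Tensor(np.ones(num_features, np.float32))
bias = Tensor(np.zeros(num_features, np.float32))
running_mean = Tensor(np.zeros(num_features, np.float32))
running_var = Tensor(np.ones(num_features, np.float32))

# construct() returns (out, mean, invstd, count) so that bprop() can reuse the
# saved per-batch statistics when computing grad_input/grad_weight/grad_bias.
out, mean, invstd, count = bn_inner(
    x, weight, bias, running_mean, running_var,
    1e-5, 0.1, GlobalComm.WORLD_COMM_GROUP, world_size)

Because AllGather now implements infer_shape / infer_dtype, the same cell can be traced in graph mode, which is what the removed construct_kbk / construct_pynative split used to handle by hand.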
mindspore/mint/nn/layer/activation.py
@@ -45,6 +45,10 @@ class SiLU(Cell):
     .. warning::
         This is an experimental API that is subject to change or deletion.
 
+    Args:
+        inplace (bool, optional): If it is ``True``, enable the in-place update function.
+            Default value: ``False``.
+
     Inputs:
         - **input** (Tensor) - `input` is :math:`x` in the preceding formula.
           Input with the data type float16 or float32. Tensor of any dimension.
@@ -63,18 +67,19 @@ class SiLU(Cell):
         >>> from mindspore import Tensor, mint
         >>> import numpy as np
         >>> input = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
-        >>> silu = mint.nn.SiLU()
+        >>> silu = mint.nn.SiLU(inplace=False)
         >>> output = silu(input)
        >>> print(output)
         [-0.269 1.762 -0.1423 1.762 -0.269]
     """
 
-    def __init__(self):
+    def __init__(self, inplace=False):
         """Initialize SiLU."""
         super(SiLU, self).__init__()
+        self.inplace = inplace
 
     def construct(self, x):
-        return mint.nn.functional.silu(x)
+        return mint.nn.functional.silu(x, self.inplace)
 
 
 class Sigmoid(Cell):
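Reviewer note: a small usage sketch of the new inplace flag (the output values are the ones from the docstring above; the in-place variant is assumed to follow the usual PyTorch-style semantics of writing the activation back into the input tensor):

import numpy as np
import mindspore
from mindspore import Tensor, mint

x = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)

silu = mint.nn.SiLU()               # default inplace=False: x is left untouched
print(silu(x))                      # [-0.269 1.762 -0.1423 1.762 -0.269]

silu_inplace = mint.nn.SiLU(inplace=True)
silu_inplace(x)                     # with inplace=True, x itself now holds the result
print(x)                            # [-0.269 1.762 -0.1423 1.762 -0.269]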
@@ -355,9 +360,6 @@ class Threshold(Cell):
         \text{value}, &\text{ otherwise }
         \end{cases}
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
     Args:
         threshold (Union[int, float]): The value of the threshold.
         value (Union[int, float]): The value to replace with when element is less than threshold.
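Reviewer note: the Threshold hunk only drops the experimental-API warning; the behaviour is unchanged. A small usage sketch consistent with the formula in the docstring (y = x where x > threshold, else value; note the strict inequality):

import numpy as np
import mindspore
from mindspore import Tensor, mint

m = mint.nn.Threshold(threshold=0.1, value=20)
x = Tensor(np.array([0.1, 0.2, 0.3]), mindspore.float32)
# 0.1 is not strictly greater than the threshold, so it is replaced by 20.
print(m(x))  # [20.   0.2   0.3]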