mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Parameter broadcast"""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
|
|
18
|
+
__all__ = ["parameter_broadcast"]
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import mindspore as ms
|
|
22
|
+
from mindspore.communication import get_rank, create_group, get_group_size
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parameter_broadcast(net, layout, cur_rank=0, initial_rank=0):
|
|
26
|
+
"""
|
|
27
|
+
Broadcast parameter to other rank in data parallel dimension.
|
|
28
|
+
|
|
29
|
+
.. warning::
|
|
30
|
+
This is an experimental API that is subject to change or deletion.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
net (Cell): The network where the parameters will be broadcasted.
|
|
34
|
+
layout (Dict): Parameter layout dictionary. Come from
|
|
35
|
+
:func:`mindspore.nn.Cell.parameter_layout_dict`
|
|
36
|
+
or read from file(for example: "strategy.ckpt" saved by using the
|
|
37
|
+
`strategy_ckpt_config` parameter of :func:`mindspore.set_auto_parallel_context`).
|
|
38
|
+
The key is param name, the value is the layout of this parameter.
|
|
39
|
+
cur_rank (int, optional): current rank id. Default: ``0``.
|
|
40
|
+
initial_rank (int, optional): Start rank id for each pipeline. Default: ``0``.
|
|
41
|
+
|
|
42
|
+
Raises:
|
|
43
|
+
ValueError: `cur_rank` is not rank id of current rank.
|
|
44
|
+
ValueError: `initial_rank` is not the start rank id of current pipeline stage.
|
|
45
|
+
ValueError: Parameter name in `layout` can not be found in
|
|
46
|
+
:func:`mindspore.nn.Cell.parameters_dict`.
|
|
47
|
+
|
|
48
|
+
Examples:
|
|
49
|
+
>>> import os
|
|
50
|
+
>>> import mindspore as ms
|
|
51
|
+
>>> import mindspore.dataset as ds
|
|
52
|
+
>>> from mindspore import nn, ops
|
|
53
|
+
>>> from mindspore.communication import init
|
|
54
|
+
>>> from mindspore.common.initializer import initializer
|
|
55
|
+
>>> from mindspore.train import Model
|
|
56
|
+
>>> from mindspore.parallel.parameter_broadcast import parameter_broadcast
|
|
57
|
+
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
58
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
59
|
+
>>> ms.set_context(max_device_memory="28GB")
|
|
60
|
+
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
|
|
61
|
+
>>> init()
|
|
62
|
+
>>> ms.set_seed(1)
|
|
63
|
+
>>> class Network(nn.Cell):
|
|
64
|
+
... def __init__(self):
|
|
65
|
+
... super().__init__()
|
|
66
|
+
... self.flatten = ops.Flatten()
|
|
67
|
+
... self.fc1_weight = ms.Parameter(initializer("normal", [28*28, 512], ms.float32))
|
|
68
|
+
... self.fc2_weight = ms.Parameter(initializer("normal", [512, 512], ms.float32))
|
|
69
|
+
... self.fc3_weight = ms.Parameter(initializer("normal", [512, 10], ms.float32))
|
|
70
|
+
... self.matmul1 = ops.MatMul()
|
|
71
|
+
... self.relu1 = ops.ReLU()
|
|
72
|
+
... self.matmul2 = ops.MatMul()
|
|
73
|
+
... self.relu2 = ops.ReLU()
|
|
74
|
+
... self.matmul3 = ops.MatMul()
|
|
75
|
+
... def construct(self, x):
|
|
76
|
+
... x = self.flatten(x)
|
|
77
|
+
... x = self.matmul1(x, self.fc1_weight)
|
|
78
|
+
... x = self.relu1(x)
|
|
79
|
+
... x = self.matmul2(x, self.fc2_weight)
|
|
80
|
+
... x = self.relu2(x)
|
|
81
|
+
... logits = self.matmul3(x, self.fc3_weight)
|
|
82
|
+
... return logits
|
|
83
|
+
>>> net = Network()
|
|
84
|
+
>>> net.matmul1.shard(((2, 4), (4, 1)))
|
|
85
|
+
>>> net.relu1.shard(((4, 1),))
|
|
86
|
+
>>> net.matmul2.shard(((1, 8), (8, 1)))
|
|
87
|
+
>>> net.relu2.shard(((8, 1),))
|
|
88
|
+
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
89
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
90
|
+
>>> dataset = create_dataset()
|
|
91
|
+
>>> optim = nn.SGD(net.trainable_params(), 1e-2)
|
|
92
|
+
>>> loss = nn.CrossEntropyLoss()
|
|
93
|
+
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
94
|
+
>>> model.train(1, dataset)
|
|
95
|
+
>>> ms.save_checkpoint(net, "./simple.ckpt", False)
|
|
96
|
+
>>> layout = model.train_network.parameter_layout_dict
|
|
97
|
+
>>> param_dict = load_checkpoint("./simple.ckpt")
|
|
98
|
+
>>> load_param_into_net(net, param_dict)
|
|
99
|
+
>>> rank_id = os.environ["RANK_ID"]
|
|
100
|
+
>>> parameter_broadcast(model.train_network, layout, int(rank_id), 0)
|
|
101
|
+
>>> class LossCallBack(Callback):
|
|
102
|
+
... def step_end(self, run_context):
|
|
103
|
+
... cb_params = run_context.original_args()
|
|
104
|
+
... print("step end, cur step num: ", cb_params.cur_step_num, flush=True)
|
|
105
|
+
>>> model.train(1, dataset, callbacks=[LossCallBack()])
|
|
106
|
+
"""
|
|
107
|
+
if not layout:
|
|
108
|
+
return
|
|
109
|
+
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
|
|
110
|
+
from mindspore.nn.wrap.cell_wrapper import AllreduceGraph
|
|
111
|
+
origin_parallel_mode = ms.get_auto_parallel_context("parallel_mode")
|
|
112
|
+
if origin_parallel_mode not in ("semi_auto_parallel", "auto_parallel"):
|
|
113
|
+
return
|
|
114
|
+
if cur_rank != get_rank():
|
|
115
|
+
raise ValueError(f"For parameter broadcast, the cur_rank: {cur_rank} is wrong.")
|
|
116
|
+
if initial_rank % (get_group_size() / ms.get_auto_parallel_context("pipeline_stages")) != 0:
|
|
117
|
+
raise ValueError(f"For parameter broadcast, the initial_rank: {initial_rank} is wrong.")
|
|
118
|
+
param_redundancy = get_parameter_redundancy(layout, initial_rank)
|
|
119
|
+
if not param_redundancy:
|
|
120
|
+
return
|
|
121
|
+
single_params = remove_param_redundancy(param_redundancy)
|
|
122
|
+
if not single_params:
|
|
123
|
+
return
|
|
124
|
+
param_redundancy_reversed = {}
|
|
125
|
+
for key, redundancy in param_redundancy.items():
|
|
126
|
+
for item in redundancy:
|
|
127
|
+
if len(item) == 1:
|
|
128
|
+
continue
|
|
129
|
+
if cur_rank in item:
|
|
130
|
+
param_redundancy_reversed.setdefault(item, []).append(key)
|
|
131
|
+
if not param_redundancy_reversed:
|
|
132
|
+
return
|
|
133
|
+
if cur_rank not in single_params:
|
|
134
|
+
return
|
|
135
|
+
net_param_dict = net.parameters_dict()
|
|
136
|
+
ms.set_auto_parallel_context(parallel_mode="hybrid_parallel")
|
|
137
|
+
for group, params in param_redundancy_reversed.items():
|
|
138
|
+
create_group(str(group), list(group))
|
|
139
|
+
allreduce_input = []
|
|
140
|
+
for param in params:
|
|
141
|
+
if param not in net_param_dict:
|
|
142
|
+
raise ValueError(f"For parameter broadcast, the param: {param} can not be found.")
|
|
143
|
+
real_param = net_param_dict[param]
|
|
144
|
+
if param not in single_params[cur_rank]:
|
|
145
|
+
real_param.set_data(ms.Tensor(np.zeros(real_param.shape), dtype=real_param.dtype))
|
|
146
|
+
allreduce_input.append(real_param)
|
|
147
|
+
if not allreduce_input:
|
|
148
|
+
continue
|
|
149
|
+
allreduce_graph = AllreduceGraph(allreduce_input, str(group))
|
|
150
|
+
allreduce_graph()
|
|
151
|
+
ms.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
|
mindspore/parallel/shard.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,13 +12,120 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""shard"""
|
|
16
16
|
|
|
17
|
+
import copy
|
|
17
18
|
import mindspore as ms
|
|
18
19
|
from mindspore import log as logger
|
|
19
20
|
from mindspore._c_expression import Shard_
|
|
20
21
|
|
|
21
22
|
|
|
23
|
+
class Layout:
|
|
24
|
+
"""
|
|
25
|
+
Parallel layout describes the detailed sharding information.
|
|
26
|
+
|
|
27
|
+
Note:
|
|
28
|
+
- It is valid only in semi auto parallel or auto parallel mode.
|
|
29
|
+
- The multiplication result of the `device_matrix` must be equal to the device count in a pipeline stage.
|
|
30
|
+
- When the layout function is invoked to constructs a sharding strategy, each alias name is only allowed to be
|
|
31
|
+
used once to shard a tensor.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
device_matrix (tuple): Describe the shape of devices arrangement, its element type is int.
|
|
35
|
+
alias_name (tuple): The alias name for each axis of device_matrix, its length shoits element type is string.
|
|
36
|
+
When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
|
|
37
|
+
copies on the corresponding partition dimension on a single card.
|
|
38
|
+
Raises:
|
|
39
|
+
TypeError: `device_matrix` is not a tuple type.
|
|
40
|
+
TypeError: `alias_name` is not a tuple type.
|
|
41
|
+
ValueError: `device_matrix` length is not equal to `alias_name` length.
|
|
42
|
+
TypeError: The element of `device_matrix` is not int type.
|
|
43
|
+
TypeError: The element of `alias_name` is not a str type.
|
|
44
|
+
ValueError: The element of `alias_name` is an empty str.
|
|
45
|
+
ValueError: The element of `alias_name` is "None".
|
|
46
|
+
ValueError: `alias_name` contains repeated element.
|
|
47
|
+
|
|
48
|
+
Examples:
|
|
49
|
+
>>> from mindspore import Layout
|
|
50
|
+
>>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
|
|
51
|
+
>>> layout0 = layout("dp", "mp")
|
|
52
|
+
>>> print(layout0.to_dict())
|
|
53
|
+
{"device_matrix": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False}
|
|
54
|
+
>>> # Total device num is 4, but split the tensor in local device into two copies.
|
|
55
|
+
>>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
|
|
56
|
+
>>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
def __init__(self, device_matrix, alias_name):
|
|
60
|
+
if not isinstance(device_matrix, tuple):
|
|
61
|
+
raise TypeError(f'device_matrix must be tuple type, but got:{type(device_matrix)}')
|
|
62
|
+
if not isinstance(alias_name, tuple):
|
|
63
|
+
raise TypeError(f'alias_name must be tuple type, but got:{type(alias_name)}')
|
|
64
|
+
if len(device_matrix) != len(alias_name):
|
|
65
|
+
raise ValueError(f'device_matrix length should be equal to alias_name length')
|
|
66
|
+
for in_ele in device_matrix:
|
|
67
|
+
if not isinstance(in_ele, int):
|
|
68
|
+
raise TypeError(f'The element of device_matrix must be int type, but got:{type(in_ele)}')
|
|
69
|
+
for in_ele in alias_name:
|
|
70
|
+
if not isinstance(in_ele, str):
|
|
71
|
+
raise TypeError(f'The element of alias_name must be str type, but got:{type(in_ele)}')
|
|
72
|
+
if not in_ele:
|
|
73
|
+
raise ValueError(f"The element of alias_name can not be empty.")
|
|
74
|
+
if in_ele == "None":
|
|
75
|
+
raise ValueError(f"The element of alias_name can not set 'None', because 'None' means no sharding.")
|
|
76
|
+
if len(set(alias_name)) != len(alias_name):
|
|
77
|
+
raise ValueError(f'Each element of alias_name {alias_name} should be different')
|
|
78
|
+
inter_key = "interleaved_parallel"
|
|
79
|
+
if inter_key in alias_name and alias_name.index(inter_key) != len(alias_name) - 1:
|
|
80
|
+
raise ValueError(f"When alias_name {alias_name} contains keyword 'interleaved_parallel',"
|
|
81
|
+
f" it should be at the last dim of alias_name, which means the virtual sharding.")
|
|
82
|
+
self._device_shape = device_matrix
|
|
83
|
+
self._alias_name = alias_name
|
|
84
|
+
self._tensor_map = None
|
|
85
|
+
|
|
86
|
+
def __call__(self, *tensor_map):
|
|
87
|
+
self._tensor_map = ()
|
|
88
|
+
writed_map = ()
|
|
89
|
+
for ele in tensor_map:
|
|
90
|
+
if isinstance(ele, tuple):
|
|
91
|
+
ele_map = ()
|
|
92
|
+
for item in ele:
|
|
93
|
+
if item == "None":
|
|
94
|
+
ele_map += (-1,)
|
|
95
|
+
continue
|
|
96
|
+
if item not in self._alias_name:
|
|
97
|
+
raise ValueError(f'The axis {item} is not found in {self._alias_name}')
|
|
98
|
+
if item in writed_map:
|
|
99
|
+
raise ValueError(f'The axis {item} has been set more than one in {self._alias_name}')
|
|
100
|
+
ele_map += (len(self._alias_name) - 1 - self._alias_name.index(item),)
|
|
101
|
+
writed_map += (item,)
|
|
102
|
+
self._tensor_map += (ele_map,)
|
|
103
|
+
continue
|
|
104
|
+
if ele == "None":
|
|
105
|
+
self._tensor_map += (-1,)
|
|
106
|
+
continue
|
|
107
|
+
if ele not in self._alias_name:
|
|
108
|
+
raise ValueError(f'The axis {ele} is not found in {self._alias_name}')
|
|
109
|
+
if ele in writed_map:
|
|
110
|
+
raise ValueError(f'The axis {ele} has been set more than one in {self._alias_name}')
|
|
111
|
+
self._tensor_map += (len(self._alias_name) - 1 - self._alias_name.index(ele),)
|
|
112
|
+
writed_map += (ele,)
|
|
113
|
+
return copy.deepcopy(self)
|
|
114
|
+
|
|
115
|
+
def to_dict(self):
|
|
116
|
+
"""
|
|
117
|
+
Transform layout to a dictionary.
|
|
118
|
+
"""
|
|
119
|
+
if self._device_shape is None:
|
|
120
|
+
raise ValueError("The device_shape of layout is None")
|
|
121
|
+
if self._tensor_map is None:
|
|
122
|
+
raise ValueError("The tensor_map of layout is None")
|
|
123
|
+
interleaved_parallel = "interleaved_parallel" in self._alias_name
|
|
124
|
+
return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map,
|
|
125
|
+
"interleaved_parallel": interleaved_parallel, "alias_name": self._alias_name}
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
|
|
22
129
|
class Shard(Shard_):
|
|
23
130
|
"""Shard operation"""
|
|
24
131
|
|
|
@@ -34,22 +141,33 @@ class Shard(Shard_):
|
|
|
34
141
|
self.level = None
|
|
35
142
|
|
|
36
143
|
def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
|
|
37
|
-
|
|
38
|
-
|
|
144
|
+
parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
|
|
145
|
+
if parallel_mode not in ("auto_parallel", "semi_auto_parallel"):
|
|
39
146
|
raise AssertionError(
|
|
40
|
-
f"Cell shard only supports auto parallel
|
|
41
|
-
if ms.context.get_context("device_target") not in
|
|
147
|
+
f"Cell shard only supports auto parallel and semi auto parallel.")
|
|
148
|
+
if ms.context.get_context("device_target") not in ("Ascend", "GPU"):
|
|
42
149
|
raise AssertionError(
|
|
43
150
|
f"'Shard' now only supports 'Ascend' and 'GPU'")
|
|
44
|
-
if
|
|
45
|
-
|
|
46
|
-
|
|
151
|
+
if parallel_mode == "auto_parallel" and \
|
|
152
|
+
ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
|
|
153
|
+
raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard' when the "
|
|
154
|
+
f"'parallel_mode' is 'auto_parallel.'")
|
|
155
|
+
|
|
47
156
|
if not isinstance(in_strategy, tuple):
|
|
48
157
|
raise TypeError(
|
|
49
|
-
f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
|
|
158
|
+
f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}.")
|
|
159
|
+
inner_type = self._check_layout_inner_type(in_strategy, "in_strategy")
|
|
160
|
+
if inner_type == "layout":
|
|
161
|
+
in_strategy = self._extract_layout_value(in_strategy, "in_strategy")
|
|
162
|
+
|
|
50
163
|
if not isinstance(out_strategy, (type(None), tuple)):
|
|
51
164
|
raise TypeError(f"For 'Shard', the 'out_strategy' should be None or tuple, "
|
|
52
|
-
f"but got {type(out_strategy).__name__}")
|
|
165
|
+
f"but got {type(out_strategy).__name__}.")
|
|
166
|
+
if not isinstance(out_strategy, type(None)):
|
|
167
|
+
logger.warning("Out_strategy is not in use currently, will be ignored in the following procedures.")
|
|
168
|
+
inner_type = self._check_layout_inner_type(out_strategy, "out_strategy")
|
|
169
|
+
if inner_type == "layout":
|
|
170
|
+
out_strategy = self._extract_layout_value(out_strategy, "out_strategy")
|
|
53
171
|
|
|
54
172
|
if not isinstance(device, str):
|
|
55
173
|
raise TypeError(f"For 'Shard', the 'device' should be a string, "
|
|
@@ -125,9 +243,9 @@ class Shard(Shard_):
|
|
|
125
243
|
f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
|
|
126
244
|
for k in parameter_plan.keys():
|
|
127
245
|
v = parameter_plan[k]
|
|
128
|
-
if not isinstance(k, str) or not isinstance(v, tuple):
|
|
246
|
+
if not isinstance(k, str) or not isinstance(v, (tuple, Layout)):
|
|
129
247
|
raise TypeError(f"For 'Shard', the type of each key and value in 'parameter_plan' must be str and "
|
|
130
|
-
f"tuple, but got {type(k).__name__} and {type(v).__name__}")
|
|
248
|
+
f"tuple/Layout, but got {type(k).__name__} and {type(v).__name__}")
|
|
131
249
|
else:
|
|
132
250
|
raise TypeError(f"For 'Shard', the 'parameter_plan' should be a dict or None, "
|
|
133
251
|
f"but got {type(parameter_plan).__name__}")
|
|
@@ -140,24 +258,75 @@ class Shard(Shard_):
|
|
|
140
258
|
f"{param_name} is not exist, ignored its setting.")
|
|
141
259
|
continue
|
|
142
260
|
|
|
143
|
-
|
|
144
|
-
param_name, param.shape, param_strategy)
|
|
261
|
+
has_set = None
|
|
145
262
|
if param.param_info.param_strategy:
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
263
|
+
has_set = "strategy"
|
|
264
|
+
if param.param_info.device_matrix:
|
|
265
|
+
has_set = "layout"
|
|
266
|
+
if has_set == "strategy":
|
|
267
|
+
logger.warning(f"The layout of parameter '{param_name}' has been set to "
|
|
268
|
+
f"{param.param_info.param_strategy}, current setting will be ignored.")
|
|
269
|
+
elif has_set == "layout":
|
|
270
|
+
logger.warning(f"The layout of parameter '{param_name}' has been set, "
|
|
271
|
+
f"current setting will be ignored.")
|
|
272
|
+
else:
|
|
273
|
+
if isinstance(param_strategy, tuple):
|
|
274
|
+
self._check_layout_is_valid(param_name, param.shape, param_strategy)
|
|
275
|
+
param.param_info.param_strategy = param_strategy
|
|
276
|
+
if isinstance(param_strategy, Layout):
|
|
277
|
+
param_layout = self._extract_layout_value((param_strategy,), "in_strategy")[0]
|
|
278
|
+
param.param_info.device_matrix = param_layout["device_matrix"]
|
|
279
|
+
param.param_info.tensor_map = param_layout["tensor_map"]
|
|
280
|
+
param.param_info.interleaved_parallel = param_layout["interleaved_parallel"]
|
|
281
|
+
param.param_info.alias_name = param_layout["alias_name"]
|
|
150
282
|
|
|
151
283
|
def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, device, level):
|
|
152
284
|
return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
|
|
153
285
|
self.out_strategy == out_strategy and self.device == device and self.level == level
|
|
154
286
|
|
|
287
|
+
def _check_layout_inner_type(self, strategy, log_info):
|
|
288
|
+
"""Check inner item type of layout, should be int or ms.Layout."""
|
|
289
|
+
strategy_set = set()
|
|
290
|
+
for stra in strategy:
|
|
291
|
+
if not isinstance(stra, (tuple, Layout)):
|
|
292
|
+
raise TypeError(
|
|
293
|
+
f"The '{log_info}' should be a tuple(tuple(int)) or tuple(mindspore.Layout), "
|
|
294
|
+
f"but got {type(stra).__name__}")
|
|
295
|
+
if isinstance(stra, Layout):
|
|
296
|
+
strategy_set.add("layout")
|
|
297
|
+
elif isinstance(stra, tuple):
|
|
298
|
+
strategy_set.add("tuple")
|
|
299
|
+
self._check_tuple_strategy(stra)
|
|
300
|
+
if len(strategy_set) != 1:
|
|
301
|
+
raise TypeError(
|
|
302
|
+
f"For 'Shard', the strategy can only pass in consistent type for all dimensions.")
|
|
303
|
+
return strategy_set.pop()
|
|
304
|
+
|
|
305
|
+
def _extract_layout_value(self, layout, log_info):
|
|
306
|
+
"""Extract parallel layout value"""
|
|
307
|
+
layout_value = None
|
|
308
|
+
if layout is not None:
|
|
309
|
+
if not isinstance(layout, tuple):
|
|
310
|
+
raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
|
|
311
|
+
layout_value = ()
|
|
312
|
+
for in_ele in layout:
|
|
313
|
+
if not isinstance(in_ele, Layout):
|
|
314
|
+
raise TypeError(f"The {log_info} item should be a object of class Layout.")
|
|
315
|
+
layout_value += (in_ele.to_dict(),)
|
|
316
|
+
return layout_value
|
|
317
|
+
|
|
318
|
+
def _check_tuple_strategy(self, dim_strategy):
|
|
319
|
+
if not all(isinstance(x, int) for x in dim_strategy):
|
|
320
|
+
raise TypeError(
|
|
321
|
+
f"The tuple strategy for each dimension should be tuple(int).")
|
|
322
|
+
|
|
155
323
|
|
|
156
324
|
def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
|
|
157
325
|
"""
|
|
158
326
|
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
|
|
159
|
-
generated by sharding propagation. In PyNative mode, use this method
|
|
160
|
-
to specify
|
|
327
|
+
generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
|
|
328
|
+
execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
|
|
329
|
+
strategy for others will be set by sharding propagation.
|
|
161
330
|
in_strategy and out_strategy define the input and output layout respectively.
|
|
162
331
|
in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
|
|
163
332
|
this input/output, and None represents data_parallel,
|
|
@@ -165,9 +334,8 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
|
|
|
165
334
|
The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
|
|
166
335
|
|
|
167
336
|
Note:
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
and the search mode (search_mode) to "sharding_propagation".
|
|
337
|
+
If ms.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
|
|
338
|
+
"auto_parallel" and the search mode (search_mode) to "sharding_propagation".
|
|
171
339
|
If the input contain Parameter, its strategy should be set in `in_strategy`.
|
|
172
340
|
|
|
173
341
|
Args:
|
|
@@ -175,15 +343,16 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
|
|
|
175
343
|
Its arguments and return value must be Tensor or Parameter.
|
|
176
344
|
If `fn` is a Cell with parameters, `fn` needs to be an instantiated object,
|
|
177
345
|
otherwise its arguments cannot be accessed.
|
|
178
|
-
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or
|
|
179
|
-
|
|
180
|
-
|
|
346
|
+
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple(int) or
|
|
347
|
+
tuple(mindspore.Layout).
|
|
348
|
+
Tuple defines the layout of the corresponding input.
|
|
181
349
|
out_strategy (Union[tuple, None]): Define the layout of outputs similar with `in_strategy`.
|
|
182
350
|
It is not in use right now. Default: ``None`` .
|
|
183
351
|
parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
|
|
184
352
|
defines the layout of the parameter like "param_name: layout".
|
|
185
353
|
The key is a parameter name of type 'str'.
|
|
186
|
-
The value is a 1-D integer tuple
|
|
354
|
+
The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple,
|
|
355
|
+
indicating the corresponding layout.
|
|
187
356
|
If the parameter name is incorrect or the corresponding parameter
|
|
188
357
|
has been set, the parameter setting will be ignored.
|
|
189
358
|
Default: ``None`` .
|
|
@@ -197,15 +366,15 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
|
|
|
197
366
|
Function, return the function that will be executed under auto parallel process.
|
|
198
367
|
|
|
199
368
|
Raises:
|
|
200
|
-
AssertionError: If
|
|
201
|
-
AssertionError: If parallel mode is not "auto_parallel".
|
|
202
|
-
AssertionError: If search_mode it not "sharding_propagation".
|
|
369
|
+
AssertionError: If parallel mode is not "auto_parallel" nor "semi_auto_parallel".
|
|
203
370
|
AssertionError: If device_target it not "Ascend" or "GPU".
|
|
204
371
|
TypeError: If `in_strategy` is not a tuple.
|
|
205
372
|
TypeError: If `out_strategy` is not a tuple or None.
|
|
373
|
+
TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.Layout).
|
|
374
|
+
TypeError: If any element in `out_strategy` is not a tuple(int) or tuple(mindspore.Layout).
|
|
206
375
|
TypeError: If `parameter_plan` is not a dict or None.
|
|
207
376
|
TypeError: If any key in `parameter_plan` is not a str.
|
|
208
|
-
TypeError: If any value in `parameter_plan` is not a tuple.
|
|
377
|
+
TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.Layout).
|
|
209
378
|
TypeError: If `device` is not a str.
|
|
210
379
|
TypeError: If `level` is not an integer.
|
|
211
380
|
|
|
@@ -215,23 +384,96 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
|
|
|
215
384
|
Examples:
|
|
216
385
|
>>> import numpy as np
|
|
217
386
|
>>> import mindspore as ms
|
|
218
|
-
>>> from mindspore import Tensor
|
|
387
|
+
>>> from mindspore import Tensor, nn
|
|
219
388
|
>>> from mindspore.communication import init
|
|
220
|
-
>>> ms.set_context(mode=ms.
|
|
389
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
221
390
|
>>> init()
|
|
222
391
|
>>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
|
|
223
|
-
...
|
|
392
|
+
... device_num=8)
|
|
393
|
+
>>>
|
|
394
|
+
>>> # Case 1: cell uses functional
|
|
395
|
+
>>> class BasicBlock(nn.Cell):
|
|
396
|
+
>>> def __init__(self):
|
|
397
|
+
>>> super(BasicBlock, self).__init__()
|
|
398
|
+
>>> self.dense1 = nn.Dense(64, 64)
|
|
399
|
+
>>> self.gelu = nn.GELU()
|
|
400
|
+
>>> def my_add(x, y):
|
|
401
|
+
>>> x = ops.abs(x)
|
|
402
|
+
>>> return x + y
|
|
403
|
+
>>> # shard a function with tuple(int) strategies
|
|
404
|
+
>>> self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
|
|
405
|
+
>>>
|
|
406
|
+
>>> def construct(self, x, u):
|
|
407
|
+
>>> x = self.gelu(x)
|
|
408
|
+
>>> y = self.gelu(u)
|
|
409
|
+
>>> y = x * y
|
|
410
|
+
>>> x = self.dense1(x)
|
|
411
|
+
>>> x = self.shard_my_add(x, y)
|
|
412
|
+
>>> return x
|
|
413
|
+
>>>
|
|
414
|
+
>>> class NetForward(nn.Cell):
|
|
415
|
+
>>> def __init__(self):
|
|
416
|
+
>>> super(NetForward, self).__init__()
|
|
417
|
+
>>> self.block1 = BasicBlock()
|
|
418
|
+
>>> self.block2 = BasicBlock()
|
|
419
|
+
>>> self.matmul = ops.MatMul()
|
|
420
|
+
>>>
|
|
421
|
+
>>> def construct(self, x, y):
|
|
422
|
+
>>> x = self.matmul(x, y)
|
|
423
|
+
>>> x = self.block1(x, x)
|
|
424
|
+
>>> x = self.block2(x, x)
|
|
425
|
+
>>> return x
|
|
426
|
+
>>>
|
|
427
|
+
>>> class Net(nn.Cell):
|
|
428
|
+
>>> def __init__(self):
|
|
429
|
+
>>> super(Net, self).__init__()
|
|
430
|
+
>>> # setting cell sharding strategy and parameter_plan by tuple(int)
|
|
431
|
+
>>> self.layer_net1 = NetForward()
|
|
432
|
+
>>> self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
|
|
433
|
+
... parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
|
|
434
|
+
>>>
|
|
435
|
+
>>> # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
|
|
436
|
+
>>> self.layer_net2 = NetForward()
|
|
437
|
+
>>> layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
|
|
438
|
+
>>> in_layout = (layout("dp", "mp"), layout("mp", "sp"))
|
|
439
|
+
>>> param_layout = layout("dp", "sp")
|
|
440
|
+
>>> self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
|
|
441
|
+
... parameter_plan={"self.layer_net2.block2.weight": param_layout})
|
|
442
|
+
>>> self.flatten = nn.Flatten()
|
|
443
|
+
>>> self.layer1 = nn.Dense(64, 64)
|
|
444
|
+
>>> self.layer2 = nn.Dense(64, 32)
|
|
445
|
+
>>> self.add = ops.Add()
|
|
446
|
+
>>> self.matmul = ops.MatMul()
|
|
447
|
+
>>>
|
|
448
|
+
>>> def construct(self, x, y):
|
|
449
|
+
>>> x = self.flatten(x)
|
|
450
|
+
>>> y = self.flatten(y)
|
|
451
|
+
>>> x = self.layer1(x)
|
|
452
|
+
>>> x = self.layer_net1_shard(x, y)
|
|
453
|
+
>>> x = self.layer_net2_shard(x, y)
|
|
454
|
+
>>> x = self.layer2(x)
|
|
455
|
+
>>> x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
|
|
456
|
+
>>> return x
|
|
457
|
+
>>>
|
|
458
|
+
>>> net = Net()
|
|
459
|
+
>>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
|
|
460
|
+
>>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
|
|
461
|
+
>>> net(x, y)
|
|
462
|
+
>>>
|
|
463
|
+
>>> # Case 2: function uses functional sharding
|
|
224
464
|
>>> def test_shard(x, y):
|
|
225
465
|
... return x + y
|
|
226
466
|
>>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
|
|
227
467
|
>>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
|
|
228
|
-
>>> output = ms.shard(test_shard, in_strategy=((
|
|
468
|
+
>>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
|
|
229
469
|
>>> print(output.shape)
|
|
230
470
|
(32, 10)
|
|
231
471
|
|
|
232
472
|
Tutorial Examples:
|
|
233
473
|
- `Functional Operator Sharding
|
|
234
|
-
<https://www.mindspore.cn/
|
|
474
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/shard_function_parallel.html>`_
|
|
475
|
+
- `mindspore.Layout
|
|
476
|
+
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Layout.html>`_
|
|
235
477
|
"""
|
|
236
478
|
if not isinstance(fn, (ms.nn.Cell)):
|
|
237
479
|
logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
|