mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Checkpoint related classes and functions."""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from mindspore.train.serialization import save_checkpoint
|
|
19
|
+
from mindspore.parallel._utils import _get_device_num
|
|
20
|
+
from mindspore import _checkparam as Validator
|
|
21
|
+
from mindspore.train.callback._callback import Callback
|
|
22
|
+
from mindspore import context
|
|
23
|
+
from mindspore.common.parameter import Parameter
|
|
24
|
+
from mindspore.communication import get_rank, get_group_size
|
|
25
|
+
from mindspore import log as logger
|
|
26
|
+
from mindspore.train.serialization import _get_cur_rank_dp
|
|
27
|
+
from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post
|
|
28
|
+
from mindspore._c_expression import clean_tdt_channel
|
|
29
|
+
from mindspore._c_expression import send_recv
|
|
30
|
+
from mindspore._c_expression import CollectiveManager
|
|
31
|
+
from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
|
|
32
|
+
|
|
33
|
+
def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
    """Build the checkpoint directory path used by TFT for a given step.

    Args:
        step: Step number embedded in the directory name.
        ckpt_save_path: Parent directory the step directory is placed under.
        is_tmp_file: When truthy, the directory name gets a ``_tmp`` suffix.

    Returns:
        str, ``ckpt_save_path`` joined with
        ``tft_saved_checkpoints-step_<step>[_tmp]``.
    """
    suffix = ""
    if is_tmp_file:
        suffix = "_tmp"
    dir_name = "tft_saved_checkpoints-step_" + str(step) + suffix
    return os.path.join(ckpt_save_path, dir_name)
|
|
38
|
+
|
|
39
|
+
def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
    """Callback used for TFT save ckpt function when errors occur.

    Saves the train network of the current rank into a temporary per-step,
    per-rank directory below ``cb_ctx.ckpt_save_path``.

    Args:
        step: Step number the checkpoint is saved for.
        save_info: Not used by this function.
        args: Callback parameters; ``cur_step_num``, ``cur_epoch_num``,
            ``batch_num``, ``net_outputs`` and ``train_network`` are read.
        cb_ctx: Registered context object; ``ckpt_save_path`` and
            ``global_step`` are read.
    """
    logger.info("Enter _save_checkpoint_on_failure function")
    cb_params = args
    cur_rank = get_rank()
    batch_num = cb_params.batch_num
    cur_epoch_num = cb_params.cur_epoch_num
    # When the reported step is behind the current step, derive the epoch
    # from the reported step instead of using the current epoch.
    if cb_params.cur_step_num > step:
        cur_epoch_num = (step - 1) // batch_num + 1
    step_num_in_epoch = int((step - 1) % batch_num + 1)

    append_dict = {
        "epoch_num": cur_epoch_num,
        "step_num": step,
        "cur_rank": cur_rank,
        "batch_num": batch_num,
        "__exception_save__": True,
        "global_step": Parameter([cb_ctx.global_step]),
    }
    outputs = cb_params.net_outputs
    # The third network output, when present, is stored as the loss scale.
    if isinstance(outputs, (tuple, list)) and len(outputs) >= 3:
        append_dict["loss_scale"] = outputs[2]

    ckpt_file = f"ttp_rank_{cur_rank}-{cur_epoch_num}_{step_num_in_epoch}.ckpt"
    # Join path components instead of concatenating a hard-coded "/" so the
    # result uses the platform's separator consistently.
    cur_ckpt_dir = os.path.join(_get_ckpt_dir(step, cb_ctx.ckpt_save_path, True), f"rank_{cur_rank}")
    os.makedirs(cur_ckpt_dir, exist_ok=True)
    cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
    save_checkpoint(cb_params.train_network, cur_file,
                    integrated_save=False, append_dict=append_dict)
    logger.info("Finish _save_checkpoint_on_failure function")
|
|
71
|
+
|
|
72
|
+
def _rename_save_result(step, cb_ctx):
    """Callback used for TFT rename function after ckpt save finished successfully.

    Renames the temporary step directory (``_tmp`` suffix) to its final name.

    Args:
        step: Step number whose checkpoint directory is renamed.
        cb_ctx: Registered context object; ``ckpt_save_path`` is read.
    """
    logger.info("Enter _rename_save_result function")
    save_root = cb_ctx.ckpt_save_path
    os.rename(
        _get_ckpt_dir(step, save_root, True),
        _get_ckpt_dir(step, save_root, False),
    )
    logger.info("Finish _rename_save_result function")
|
|
80
|
+
|
|
81
|
+
def _tft_exit_cb(ctx):
    """Callback used for TFT exit function.

    Logs an error, posts the TFT semaphore and terminates the process
    immediately with exit status 1.

    Args:
        ctx: Registered context object; not used by this function.
    """
    logger.error(
        "Enter mindio ttp exit process, which means other ranks occur exception, check other ranks' logs!"
    )
    _tft_sem_post()
    os._exit(1)  # pylint: disable=W0212
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
    """Callback used for TFT repair function.

    Args:
        step: Not used by this function.
        need_rebuild: Not used by this function.
        error_ranks: Not used by this function.
        repair_info: Mapping read for ``"repair_type"`` and, for send/recv
            repairs, ``"src"`` and ``"dst"`` (the first element of each is used).
        args: Callback parameters; ``network.trainable_params()`` is read.
        cb_ctx: Registered context object; ``tft`` and ``device_id`` are read.
    """
    repair_type = repair_info["repair_type"]
    repair_types = cb_ctx.tft.RepairType
    logger.info("Enter _tft_repair_callback repair type: {}".format(repair_type))

    # Both UCE repair levels repair the device first.
    if repair_type in (repair_types.RT_UCE_HIGHLEVEL.value, repair_types.RT_UCE_LOWLEVEL.value):
        logger.info("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
        _repair_device(cb_ctx.device_id)

    # High-level UCE repair and send repair transfer the trainable parameters
    # from the source rank to the destination rank.
    if repair_type in (repair_types.RT_UCE_HIGHLEVEL.value, repair_types.RT_SEND.value):
        # Implicit string concatenation keeps source indentation out of the
        # message, unlike a backslash continuation inside the literal.
        logger.info("Enter _tft_repair_callback SEND_RECV repair type: "
                    "{}, src_rank:{}, dst_rank: {}".format(repair_type, repair_info["src"], repair_info["dst"]))
        cb_params = args
        src_rank = repair_info["src"][0]
        dst_rank = repair_info["dst"][0]
        send_recv(cb_params.network.trainable_params(), src_rank, dst_rank)
    logger.info("Finish _tft_repair_callback")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _tft_clean_callback(is_uce_error, ctx):
    """Callback used for TFT clean function.

    Args:
        is_uce_error: When truthy, the UCE memory info is fetched and the UCE
            process strategy decides the return code.
        ctx: Registered context object; ``device_id`` is read.

    Returns:
        int, ``0`` when there is no UCE error or the strategy is
        ``"RS_UCE_HIGHLEVEL"``, ``2`` for ``"RS_UCE_LOWLEVEL"``, ``1`` for any
        other strategy.
    """
    logger.info("Enter _tft_clean_callback")
    ret = 0
    if is_uce_error:
        _get_uce_mem_info(ctx.device_id)
        strategy = _get_uce_process_strategy()
        logger.info(f"_tft_clean_callback err_strategy: {strategy}")
        if strategy == "RS_UCE_LOWLEVEL":
            ret = 2
        elif strategy != "RS_UCE_HIGHLEVEL":
            ret = 1
    clean_tdt_channel()
    logger.info("Enter _tft_clean_callback resume_hccl_comm")
    CollectiveManager.get_instance().resume_hccl_comm()
    logger.info(f"Finish _tft_clean_callback, ret: {ret}")
    return ret
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _tft_stop_callback(cb_ctx):
    """Callback used for TFT stop function.

    Args:
        cb_ctx: Registered context object; ``device_id`` is read and passed
            to ``_stop_device``.
    """
    device_id = cb_ctx.device_id
    logger.info(f"Enter _tft_stop_callback device_id: {device_id}")
    _stop_device(device_id)
    logger.info("Finish _tft_stop_callback")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class TFTRegister(Callback):
    """
    This callback is used to enable the TFT feature
    `MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/mindio/mindiottp/mindiottp001.html>`_.
    This callback will execute TFT operations during training process, such as TFT init, report and exception handle.

    Note:
        Required for Ascend graph mode only. And sink size must be less than or equal to 1.

    Args:
        ctrl_rank_id (int): TFT controller's running rank_id, used for init TFT controller.
        ctrl_ip (str): TFT controller's ip address, used for init TFT controller.
        ctrl_port (int): TFT controller's ip port, used for init TFT controller and processor.
        ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
            named tft_saved_checkpoints-step_{cur_step_num} under this directory.

    Raises:
        Exception: TFT init failed.
        ModuleNotFoundError: Mindio TFT whl package is not installed.
        ValueError: The environment variable `MS_ENABLE_TFT` contains neither ``TTP:1`` nor ``UCE:1``.
        ValueError: The device target is not Ascend or the mode is not graph mode.

    Examples:
        >>> import numpy as np
        >>> import os
        >>> import math
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, Parameter, train
        >>> from mindspore.communication import init, get_rank
        >>> from mindspore.common.initializer import initializer, HeUniform
        >>> from mindspore.train import Model, TFTRegister
        >>> from mindspore import dataset as ds
        >>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
        >>> init()
        >>> ms.set_seed(1)
        >>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
        ...                             "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
        >>> class MatMulCell(nn.Cell):
        ...     def __init__(self, param=None, shape=None):
        ...         super().__init__()
        ...         if shape is None:
        ...             shape = [28 * 28, 512]
        ...         weight_init = HeUniform(math.sqrt(5))
        ...         self.param = Parameter(initializer(weight_init, shape), name="param")
        ...         if param is not None:
        ...             self.param = param
        ...         self.print = ops.Print()
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.param)
        ...         self.print("out is:", out)
        ...         return out
        >>>
        >>> class Network(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.flatten = nn.Flatten()
        ...         self.layer1 = MatMulCell()
        ...         self.relu1 = nn.ReLU()
        ...         self.layer2 = nn.Dense(512, 512)
        ...         self.relu2 = nn.ReLU()
        ...         self.layer3 = nn.Dense(512, 10)
        ...
        ...     def construct(self, x):
        ...         x = self.flatten(x)
        ...         x = self.layer1(x)
        ...         x = self.relu1(x)
        ...         x = self.layer2(x)
        ...         x = self.relu2(x)
        ...         logits = self.layer3(x)
        ...         return logits
        >>>
        >>> net = Network()
        >>> net.layer1.pipeline_stage = 0
        >>> net.relu1.pipeline_stage = 0
        >>> net.layer2.pipeline_stage = 0
        >>> net.relu2.pipeline_stage = 1
        >>> net.layer3.pipeline_stage = 1
        >>>
        >>> def create_dataset(batch_size):
        ...     dataset_path = os.getenv("DATA_PATH")
        ...     dataset = ds.MnistDataset(dataset_path)
        ...     image_transforms = [
        ...         ds.vision.Rescale(1.0 / 255.0, 0),
        ...         ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ...         ds.vision.HWC2CHW()
        ...     ]
        ...     label_transform = ds.transforms.TypeCast(ms.int32)
        ...     dataset = dataset.map(image_transforms, 'image')
        ...     dataset = dataset.map(label_transform, 'label')
        ...     dataset = dataset.batch(batch_size)
        ...     return dataset
        >>>
        >>> data_set = create_dataset(32)
        >>>
        >>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
        >>> optimizer_wrapper = nn.OptTFTWrapper(optimizer)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>>
        >>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
        >>> net_with_loss.set_train()
        >>> model = Model(net_with_loss, optimizer=optimizer)
        >>> tft_cb = TFTRegister(0, "192.168.0.1", 2000, "./tft_checkpoint/")
        >>> loss_cb = train.LossMonitor(1)
        >>> model.train(1, data_set, callbacks=[tft_cb, loss_cb])
    """

    def __init__(self, ctrl_rank_id, ctrl_ip, ctrl_port, ckpt_save_path):
        super(TFTRegister, self).__init__()

        tft_env = os.getenv("MS_ENABLE_TFT", "")
        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
            raise ValueError("MindIO TFT register need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
        mode = context.get_context("mode")
        device_target = context.get_context("device_target")
        if device_target != "Ascend" or mode != context.GRAPH_MODE:
            raise ValueError("MindIO adapter only support on Ascend device with GRAPH Mode!")

        # let it raise errors if not install mindio_tft package
        from mindio_ttp import framework_ttp as tft
        self.tft = tft
        self.global_step = 0
        Validator.check_non_negative_int(ctrl_port)
        self.has_init_replica = False
        self._controller_ip = ctrl_ip
        self._controller_rank_id = ctrl_rank_id
        self._controller_port = ctrl_port
        self.device_id = context.get_context("device_id")
        # Assign every attribute the registered handlers read before
        # _init_tft() hands `self` to the TFT library as handler context.
        self.ckpt_save_path = ckpt_save_path
        self._init_tft()

    def _set_tft_optimizer_replica(self, run_context):
        """ set Mindio TFT optimizer replica info, used internal. """
        cur_rank = get_rank()
        cb_params = run_context.original_args()
        train_network = cb_params.train_network
        # in data_parallel mode, every ranks has same train parameters
        if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
            group_size = get_group_size()
            dp = tuple(range(group_size))
        else:
            # Prefer the parameter layout dict; fall back to the network when it is empty.
            param_layout_dict = train_network.parameter_layout_dict
            dp = _get_cur_rank_dp(param_layout_dict) if param_layout_dict else _get_cur_rank_dp(train_network)
        logger.warning(f"Set TFT replica with dp: {dp}.")
        replica_info = [
            {
                "type": 1,
                "rank_list": dp,
                "replica_cnt": len(dp),
                "replica_shift": 0
            }
        ]
        self.tft.tft_set_optimizer_replica(cur_rank, replica_info)

    def _init_tft(self):
        """ Init Mindio TFT, used internal. """
        logger.info("Begin to init tft.")
        self.tft.tft_register_save_ckpt_handler(_save_checkpoint_on_failure, self)
        self.tft.tft_register_rename_handler(_rename_save_result, self)
        self.tft.tft_register_exit_handler(_tft_exit_cb, self)
        self.tft.tft_register_stop_handler(_tft_stop_callback, self)
        self.tft.tft_register_clean_handler(_tft_clean_callback, self)
        self.tft.tft_register_repair_handler(_tft_repair_callback, self)

        world_size = _get_device_num()
        cur_rank = get_rank()
        enable_local_copy = False
        enable_arf = False
        enable_zit = False
        enable_tls = False
        tls_key_dir = ""

        # Only the rank chosen as controller starts the TFT controller;
        # every rank starts a processor.
        if cur_rank == self._controller_rank_id:
            logger.info(f"Begin to start tft controller on rank_id:{cur_rank}")
            self.tft.tft_init_controller(cur_rank, world_size, enable_local_copy, enable_arf, enable_zit)
            self.tft.tft_start_controller(self._controller_ip, self._controller_port, enable_tls, tls_key_dir)
            logger.info("Finish start tft controller.")

        logger.info("Begin to start tft processor.")
        self.tft.tft_init_processor(cur_rank, world_size, enable_local_copy, enable_tls, tls_key_dir)
        self.tft.tft_start_processor(self._controller_ip, self._controller_port)
        logger.info("Finished start tft processor.")

    def on_train_step_end(self, run_context):
        """
        And report status to MindIO TFT after every step finished.

        Args:
            run_context (RunContext): Context of the train running. Refer to
                                      :class:`mindspore.train.RunContext` for detail.
        """
        if not self.has_init_replica:
            self._set_tft_optimizer_replica(run_context)
            # Mark as done only after the replica info was set successfully.
            self.has_init_replica = True
        cb_params = run_context.original_args()
        if cb_params.optimizer is not None:
            self.global_step = int(cb_params.optimizer.global_step.data)
        else:
            self.global_step = int(cb_params.network.optimizer.global_step.data)
        logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
        self.tft.tft_end_updating_os(cb_params.cur_step_num)
        logger.info("END Set optimizer finish step status to TFT.")

    def on_train_begin(self, run_context):
        """
        Check the sink size and register the callback parameters to MindIO TFT at the beginning of training.

        Args:
            run_context (RunContext): Context of the train running. Refer to
                                      :class:`mindspore.train.RunContext` for detail.

        Raises:
            ValueError: `sink_size` in the callback parameters is greater than 1.
        """
        cb_params = run_context.original_args()
        sink_size = cb_params.get("sink_size", 0)
        if sink_size > 1:
            raise ValueError("TFT feature doesn't support sink_size > 1.")
        logger.info("Set set args to TFT.")
        self.tft.tft_set_step_args(cb_params)

    def end(self, run_context):
        """
        Destroy the MindIO TFT controller (on the controller rank only) and the processor.

        Args:
            run_context (RunContext): Context of the train running. Not used by this method.
        """
        cur_rank = get_rank()
        if cur_rank == self._controller_rank_id:
            self.tft.tft_destroy_controller()
        self.tft.tft_destroy_processor()
|
|
@@ -30,7 +30,7 @@ class TimeMonitor(Callback):
|
|
|
30
30
|
if the program get `batch_num` during training, `data_size` will be set to `batch_num`,
|
|
31
31
|
otherwise `data_size` will be used. Default: ``None`` .
|
|
32
32
|
|
|
33
|
-
data_time (bool): Whether to
|
|
33
|
+
data_time (bool): Whether to show the average time of fetching data in Host.
|
|
34
34
|
Note that data fetch and network compute are processed sequentially in non dataset sink mode, while
|
|
35
35
|
they are asynchronous in dataset sink mode. Default: ``False`` .
|
|
36
36
|
|
|
@@ -43,13 +43,13 @@ class TimeMonitor(Callback):
|
|
|
43
43
|
>>> from mindspore.train import Model, TimeMonitor
|
|
44
44
|
>>>
|
|
45
45
|
>>> # Define the network structure of LeNet5. Refer to
|
|
46
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
46
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
47
47
|
>>> net = LeNet5()
|
|
48
48
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
49
49
|
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
|
|
50
50
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
51
51
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
52
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
52
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
53
53
|
>>> dataset = create_dataset()
|
|
54
54
|
>>> time_monitor = TimeMonitor()
|
|
55
55
|
>>> model.train(10, dataset, callbacks=time_monitor)
|
mindspore/train/data_sink.py
CHANGED
|
@@ -130,10 +130,11 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
|
|
|
130
130
|
A wrapper function to generate a function for the input function.
|
|
131
131
|
|
|
132
132
|
Note:
|
|
133
|
-
When using data sinking, the dataset will be automatically
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
133
|
+
When using data sinking, the dataset will be automatically looped to the device. The device side can cache up
|
|
134
|
+
to 100 batches of data and occupy no more than 2GB of memory. At this time, only the number of steps for each
|
|
135
|
+
sinking `sink_size` needs to be considered. `sink_size` defaults to ``1``, indicating that each epoch only
|
|
136
|
+
takes one batch of data from the cache for training and outputs a loss. If `sink_size` is greater than 1, each
|
|
137
|
+
epoch takes out `sink_size` batches of data from the cache for training and outputs a loss.
|
|
137
138
|
|
|
138
139
|
Args:
|
|
139
140
|
fn (Function): The Python function that will be run with dataset.
|
|
@@ -145,7 +146,7 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
|
|
|
145
146
|
input_signature (Union[Tensor, List or Tuple of Tensors]): The Tensor which describes the input arguments.
|
|
146
147
|
The shape and dtype of the Tensor will be supplied to this function. If input_signature is specified,
|
|
147
148
|
each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept `**kwargs`. The shape
|
|
148
|
-
and dtype of actual inputs should keep the same as input_signature
|
|
149
|
+
and dtype of actual inputs should keep the same as `input_signature`. Otherwise, TypeError will be raised.
|
|
149
150
|
Default: ``None`` .
|
|
150
151
|
|
|
151
152
|
Returns:
|
|
@@ -16,18 +16,20 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
18
|
import math
|
|
19
|
+
import copy
|
|
19
20
|
|
|
20
21
|
from mindspore import _checkparam as Validator
|
|
21
22
|
from mindspore import log as logger
|
|
22
23
|
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
|
|
23
24
|
from mindspore.common.dtype import pytype_to_dtype
|
|
24
|
-
from mindspore.common.api import _cell_graph_executor
|
|
25
|
+
from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
|
|
25
26
|
from mindspore.common._utils import is_shape_unknown
|
|
26
27
|
from mindspore.dataset.engine import offload
|
|
27
28
|
from mindspore import context, nn
|
|
28
29
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
|
|
29
30
|
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
|
|
30
|
-
_to_full_shapes, _get_pipeline_stages
|
|
31
|
+
_to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
|
|
32
|
+
_origin_shapes, _dynamic_shape_for_dataset
|
|
31
33
|
from mindspore.parallel._ps_context import _is_role_sched
|
|
32
34
|
from mindspore.ops import operations as P
|
|
33
35
|
from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape
|
|
@@ -95,6 +97,20 @@ class _DataWrapper(nn.Cell):
|
|
|
95
97
|
self.add_flags(**flags)
|
|
96
98
|
self.get_next = P.GetNext(
|
|
97
99
|
dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
|
100
|
+
if network.get_inputs() is not None:
|
|
101
|
+
network_inputs = network.get_inputs()
|
|
102
|
+
is_fullmode = _is_args_fullmode(network_inputs, False)
|
|
103
|
+
if is_fullmode:
|
|
104
|
+
symbol_inputs = [getattr(inp, "symbolic_shape", None) for inp in network.get_inputs()]
|
|
105
|
+
else:
|
|
106
|
+
symbol_inputs = [None for _ in dataset_shapes]
|
|
107
|
+
arg_specified = network_inputs.get(ARG_SPECIFIED, [])
|
|
108
|
+
for idx, inp in arg_specified:
|
|
109
|
+
symbol_inputs[idx] = getattr(inp, "symbolic_shape", None)
|
|
110
|
+
symbols_for_parallel = _change_symbols_for_parallel(dataset_shapes, copy.deepcopy(symbol_inputs))
|
|
111
|
+
if any((s is not None for s in symbols_for_parallel)):
|
|
112
|
+
self.get_next.add_prim_attr("symbols", symbol_inputs)
|
|
113
|
+
self.get_next.add_prim_attr("symbols_for_parallel", symbols_for_parallel)
|
|
98
114
|
self.network = network
|
|
99
115
|
self._get_attr_from_cell(network)
|
|
100
116
|
|
|
@@ -129,7 +145,15 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
|
|
|
129
145
|
return _generate_dataset_sink_mode_net(network, new_shapes, dataset_types, queue_name)
|
|
130
146
|
|
|
131
147
|
if network.get_inputs() and None not in network.get_inputs():
|
|
132
|
-
|
|
148
|
+
if _is_in_auto_parallel_mode():
|
|
149
|
+
# here, the dataset shapes has been processed by full_shape(), so need to resume it to original shape
|
|
150
|
+
# the _check_inputs() will change static origin_shape to dynamic shape
|
|
151
|
+
# after _check_inputs(), convert dataset_shapes to dynamic shape
|
|
152
|
+
origin_shape = _origin_shapes(dataset_shapes)
|
|
153
|
+
_check_inputs(network.get_inputs(), origin_shape, dataset_types)
|
|
154
|
+
dataset_shapes = _dynamic_shape_for_dataset(dataset_shapes, origin_shape)
|
|
155
|
+
else:
|
|
156
|
+
_check_inputs(network.get_inputs(), dataset_shapes, dataset_types)
|
|
133
157
|
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
|
134
158
|
dataset_shapes = tuple([(-2,)] * len(dataset_shapes))
|
|
135
159
|
network = _generate_dataset_sink_mode_net(
|
|
@@ -141,6 +165,13 @@ def _check_inputs(network_shapes, dataset_shapes, dataset_types):
|
|
|
141
165
|
"""
|
|
142
166
|
Check if set inputs are correct.
|
|
143
167
|
"""
|
|
168
|
+
if not _is_args_fullmode(network_shapes, False):
|
|
169
|
+
temp_network_shapes = [None for _ in dataset_shapes]
|
|
170
|
+
arg_specified = network_shapes.get(ARG_SPECIFIED, [])
|
|
171
|
+
for idx, inp in arg_specified:
|
|
172
|
+
temp_network_shapes[idx] = inp
|
|
173
|
+
network_shapes = temp_network_shapes
|
|
174
|
+
|
|
144
175
|
for tensor_index, ele_dataset_shape in enumerate(dataset_shapes):
|
|
145
176
|
if network_shapes[tensor_index] is None:
|
|
146
177
|
continue
|
|
@@ -182,7 +213,8 @@ def _get_dataset_aux(dataset):
|
|
|
182
213
|
def connect_network_with_dataset(network, dataset_helper):
|
|
183
214
|
"""
|
|
184
215
|
Connect the `network` with dataset in `dataset_helper`. Only supported in `sink mode
|
|
185
|
-
<https://mindspore.cn/
|
|
216
|
+
<https://mindspore.cn/docs/en/master/model_train/train_process/train_optimize.html>`_,
|
|
217
|
+
(dataset_sink_mode=True).
|
|
186
218
|
|
|
187
219
|
Args:
|
|
188
220
|
network (Cell): The training network for dataset.
|
|
@@ -230,29 +262,36 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
230
262
|
"The dataset has been connected to other network, please check the code.")
|
|
231
263
|
is_dynamic = bool(network.get_inputs())
|
|
232
264
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
233
|
-
|
|
265
|
+
# In pipeline parallel, some stages have no GetNext, should not get in.
|
|
266
|
+
use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
|
|
267
|
+
if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel:
|
|
234
268
|
dataset_types, dataset_shapes = dataset_helper.get_data_info()
|
|
269
|
+
# Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
|
|
270
|
+
if _need_to_full():
|
|
271
|
+
dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
|
|
235
272
|
dataset_types = [pytype_to_dtype(x) for x in dataset_types]
|
|
236
273
|
if not is_dynamic:
|
|
237
274
|
dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
|
|
238
275
|
key = str(dataset_types) + str(dataset_shapes)
|
|
239
|
-
|
|
240
|
-
if hasattr(aux,
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
device_num = _get_device_num() // _get_pipeline_stages()
|
|
245
|
-
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
|
246
|
-
|
|
247
|
-
network = _generate_dataset_sink_mode_net(
|
|
248
|
-
network, dataset_shapes, dataset_types, queue_name)
|
|
249
|
-
if hasattr(aux, '__network_manage__'):
|
|
250
|
-
aux.__network_manage__ = aux.__network_manage__
|
|
276
|
+
|
|
277
|
+
if hasattr(aux, "__shape_type__") and aux.__shape_type__ != key:
|
|
278
|
+
_auto_dynamic_shape.update_phase_and_compile_args(dataset_shapes, key, True, aux)
|
|
279
|
+
if hasattr(aux, '__network_manage__') and key in aux.__network_manage__:
|
|
280
|
+
network = aux.__network_manage__[key]
|
|
251
281
|
else:
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
282
|
+
if _need_to_full():
|
|
283
|
+
device_num = _get_device_num() // _get_pipeline_stages()
|
|
284
|
+
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
|
285
|
+
|
|
286
|
+
network = _generate_dataset_sink_mode_net(
|
|
287
|
+
network, dataset_shapes, dataset_types, queue_name)
|
|
288
|
+
if hasattr(aux, '__network_manage__'):
|
|
289
|
+
aux.__network_manage__ = aux.__network_manage__
|
|
290
|
+
else:
|
|
291
|
+
aux.__network_manage__ = dict()
|
|
292
|
+
aux.__network_manage__[key] = network
|
|
293
|
+
network.add_flags(sink_mode=True)
|
|
294
|
+
return network
|
|
256
295
|
|
|
257
296
|
if hasattr(aux, '__sink_network__'):
|
|
258
297
|
network = aux.__sink_network__
|
|
@@ -263,8 +302,11 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
263
302
|
network = _generate_network_with_dataset(
|
|
264
303
|
network, dataset_helper, queue_name)
|
|
265
304
|
aux.__sink_network__ = network
|
|
305
|
+
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
|
306
|
+
aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
|
|
266
307
|
|
|
267
|
-
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic)
|
|
308
|
+
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
|
|
309
|
+
not use_pipeline_parallel:
|
|
268
310
|
dataset_helper.get_data_info()
|
|
269
311
|
network.add_flags(sink_mode=True)
|
|
270
312
|
return network
|
|
@@ -282,7 +324,7 @@ class DatasetHelper:
|
|
|
282
324
|
|
|
283
325
|
Args:
|
|
284
326
|
dataset (Dataset): The dataset iterator. The dataset can be generated by dataset generator API in
|
|
285
|
-
|
|
327
|
+
`mindspore.dataset` module, such as :class:`mindspore.dataset.ImageFolderDataset`.
|
|
286
328
|
dataset_sink_mode (bool): If the value is True, GetNext is employed to fetch the data at device through the
|
|
287
329
|
dataset pipeline, otherwise fetch the data at host by iterating through the dataset.
|
|
288
330
|
Default: ``True``.
|
|
@@ -458,6 +500,7 @@ class DatasetHelper:
|
|
|
458
500
|
'''
|
|
459
501
|
Inner class for parsing send info.
|
|
460
502
|
'''
|
|
503
|
+
|
|
461
504
|
def __init__(self, send_info, run_context):
|
|
462
505
|
self.info_ = {}
|
|
463
506
|
self.sink_size = run_context.original_args()["batch_num"]
|
|
@@ -62,7 +62,7 @@ class FixedLossScaleManager(LossScaleManager):
|
|
|
62
62
|
>>> from mindspore import amp, nn
|
|
63
63
|
>>>
|
|
64
64
|
>>> # Define the network structure of LeNet5. Refer to
|
|
65
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
65
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
66
66
|
>>> net = LeNet5()
|
|
67
67
|
>>> loss_scale = 1024.0
|
|
68
68
|
>>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
|
|
@@ -136,7 +136,7 @@ class DynamicLossScaleManager(LossScaleManager):
|
|
|
136
136
|
>>> from mindspore import amp, nn
|
|
137
137
|
>>>
|
|
138
138
|
>>> # Define the network structure of LeNet5. Refer to
|
|
139
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
139
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
140
140
|
>>> net = LeNet5()
|
|
141
141
|
>>> loss_scale_manager = amp.DynamicLossScaleManager()
|
|
142
142
|
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
@@ -32,9 +32,9 @@ class Accuracy(EvaluationBase):
|
|
|
32
32
|
{\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
|
-
eval_type (str): The metric to calculate the accuracy over a dataset. Supports 'classification' and
|
|
36
|
-
'multilabel'
|
|
37
|
-
labels. Default: ``'classification'`` .
|
|
35
|
+
eval_type (str): The metric to calculate the accuracy over a dataset. Supports ``'classification'`` and
|
|
36
|
+
``'multilabel'``. ``'classification'`` means the dataset label is single.
|
|
37
|
+
``'multilabel'`` means the dataset has multiple labels. Default: ``'classification'`` .
|
|
38
38
|
|
|
39
39
|
Supported Platforms:
|
|
40
40
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -67,19 +67,19 @@ class Accuracy(EvaluationBase):
|
|
|
67
67
|
@rearrange_inputs
|
|
68
68
|
def update(self, *inputs):
|
|
69
69
|
"""
|
|
70
|
-
Updates the local variables. For 'classification'
|
|
71
|
-
matches the label, the predict result is correct. For 'multilabel'
|
|
70
|
+
Updates the local variables. For ``'classification'``, if the index of the maximum of the predict value
|
|
71
|
+
matches the label, the predict result is correct. For ``'multilabel'``, the predict value match the label,
|
|
72
72
|
the predict result is correct.
|
|
73
73
|
|
|
74
74
|
Args:
|
|
75
75
|
inputs: Logits and labels. `y_pred` stands for logits, `y` stands for labels. `y_pred` and `y` must be a
|
|
76
76
|
`Tensor`, a list or an array.
|
|
77
77
|
|
|
78
|
-
- For the 'classification' evaluation type, `y_pred` is a list of floating numbers in range
|
|
78
|
+
- For the ``'classification'`` evaluation type, `y_pred` is a list of floating numbers in range
|
|
79
79
|
:math:`[0, 1]` and the shape is :math:`(N, C)` in most cases (not strictly), where :math:`N`
|
|
80
80
|
is the number of cases and :math:`C` is the number of categories. `y` must be in one-hot format
|
|
81
81
|
that shape is :math:`(N, C)`, or can be transformed to one-hot format that shape is :math:`(N,)`.
|
|
82
|
-
- For 'multilabel' evaluation type, the value of `y_pred` and `y` can only be 0 or 1, indices with 1
|
|
82
|
+
- For ``'multilabel'`` evaluation type, the value of `y_pred` and `y` can only be 0 or 1, indices with 1
|
|
83
83
|
indicate the positive category. The shape of `y_pred` and `y` are both :math:`(N, C)`.
|
|
84
84
|
|
|
85
85
|
Raises:
|
|
@@ -33,10 +33,10 @@ class ConfusionMatrix(Metric):
|
|
|
33
33
|
num_classes (int): Number of classes in the dataset.
|
|
34
34
|
normalize (str): Normalization mode for confusion matrix. Default: ``"no_norm"`` . Choose from:
|
|
35
35
|
|
|
36
|
-
-
|
|
37
|
-
-
|
|
38
|
-
-
|
|
39
|
-
-
|
|
36
|
+
- ``"no_norm"`` : No Normalization is used. Default: ``None``.
|
|
37
|
+
- ``"target"`` : Normalization based on target value.
|
|
38
|
+
- ``"prediction"`` : Normalization based on predicted value.
|
|
39
|
+
- ``"all"`` : Normalization over the whole matrix.
|
|
40
40
|
|
|
41
41
|
threshold (float): The threshold used to compare with the input tensor. Default: ``0.5`` .
|
|
42
42
|
|
|
@@ -224,8 +224,10 @@ class ConfusionMatrixMetric(Metric):
|
|
|
224
224
|
|
|
225
225
|
- y_pred (ndarray): The batch data shape is :math:`(N, C, ...)`
|
|
226
226
|
or :math:`(N, ...)`, representing onehot format
|
|
227
|
-
or category index format respectively. As for classification tasks,
|
|
228
|
-
|
|
227
|
+
or category index format respectively. As for classification tasks,
|
|
228
|
+
y_pred should have the shape :math:`(B, N)`
|
|
229
|
+
where N is larger than 1. As for segmentation tasks,
|
|
230
|
+
the shape should be :math:`(B, N, H, W)` or :math:`(B, N, H, W, D)`.
|
|
229
231
|
- y (ndarray): It must be one-hot format. The batch data shape is :math:`(N, C, ...)`.
|
|
230
232
|
|
|
231
233
|
Raises:
|