mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +76 -18
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +258 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +174 -62
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +142 -240
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +51 -24
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +15 -4
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +8 -9
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +411 -106
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +6 -8
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +260 -0
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +30 -11
- mindspore/common/recompute.py +262 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +272 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +468 -494
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +1140 -0
- mindspore/communication/management.py +115 -102
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +346 -63
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +140 -83
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +78 -59
- mindspore/dataset/engine/datasets_vision.py +77 -73
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +56 -38
- mindspore/dataset/engine/validators.py +11 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +8 -8
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +14 -2
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/dataset/vision.h +54 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +76 -58
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/mint/__init__.py +1137 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +512 -0
- mindspore/mint/nn/functional.py +573 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +185 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +1 -0
- mindspore/nn/cell.py +213 -257
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +109 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/layer/activation.py +83 -93
- mindspore/nn/layer/basic.py +177 -82
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +101 -43
- mindspore/nn/layer/embedding_service.py +531 -0
- mindspore/nn/layer/embedding_service_layer.py +393 -0
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +18 -9
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +62 -83
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +58 -13
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +32 -9
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +6 -6
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +62 -68
- mindspore/numpy/utils.py +3 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +89 -34
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +164 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +130 -58
- mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +980 -0
- mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +121 -23
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +53 -0
- mindspore/ops/extend/array_func.py +218 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +308 -0
- mindspore/ops/function/__init__.py +31 -11
- mindspore/ops/function/array_func.py +846 -1735
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +27 -20
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +913 -2791
- mindspore/ops/function/nn_func.py +1439 -885
- mindspore/ops/function/other_func.py +6 -7
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +254 -108
- mindspore/ops/function/reshard_func.py +102 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +342 -343
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +32 -23
- mindspore/ops/operations/_grad_ops.py +21 -853
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +107 -518
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +108 -2705
- mindspore/ops/operations/comm_ops.py +801 -118
- mindspore/ops/operations/custom_ops.py +61 -120
- mindspore/ops/operations/debug_ops.py +104 -35
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
- mindspore/ops/operations/math_ops.py +572 -4667
- mindspore/ops/operations/nn_ops.py +248 -2162
- mindspore/ops/operations/other_ops.py +53 -45
- mindspore/ops/operations/random_ops.py +4 -53
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +204 -103
- mindspore/ops/silent_check.py +5 -5
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +250 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_ops.py +1084 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +968 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +354 -0
- mindspore/ops_generate/template.py +239 -0
- mindspore/parallel/__init__.py +6 -4
- mindspore/parallel/_auto_parallel_context.py +73 -3
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +29 -13
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +18 -11
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +161 -6
- mindspore/parallel/algo_parameter_config.py +6 -8
- mindspore/parallel/checkpoint_transform.py +191 -32
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +344 -0
- mindspore/parallel/cluster/process_entity/_utils.py +126 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +128 -17
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +3 -2
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +125 -0
- mindspore/profiler/envprofiling.py +2 -2
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +147 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +296 -133
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +445 -190
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -5
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +143 -29
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +238 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_mindio_ttp.py +443 -0
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +60 -21
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +290 -60
- mindspore/train/serialization.py +495 -220
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
- mindspore-2.3.0.dist-info/RECORD +1400 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Checkpoint related classes and functions."""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import copy
|
|
20
|
+
from mindspore.train.serialization import save_checkpoint, _convert_cell_param_and_names_to_dict, _get_merged_param_data
|
|
21
|
+
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
|
|
22
|
+
from mindspore.parallel._utils import _get_device_num
|
|
23
|
+
from mindspore import _checkparam as Validator
|
|
24
|
+
from mindspore.train.callback._callback import Callback
|
|
25
|
+
from mindspore.common.tensor import Tensor
|
|
26
|
+
from mindspore import context
|
|
27
|
+
import mindspore as ms
|
|
28
|
+
from mindspore.communication import get_rank
|
|
29
|
+
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
30
|
+
|
|
31
|
+
from mindspore.train._utils import get_parameter_redundancy
|
|
32
|
+
from mindspore import log as logger
|
|
33
|
+
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
|
34
|
+
from mindspore.common.api import _get_parameter_layout
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _get_dp_from_layout(parameter_layout_dict):
|
|
38
|
+
""" Get dp and tp from layout dict. """
|
|
39
|
+
pp_num = _get_auto_parallel_context("pipeline_stages")
|
|
40
|
+
dev_num = _get_device_num()
|
|
41
|
+
global_rank = get_rank()
|
|
42
|
+
pipe_size = dev_num // pp_num
|
|
43
|
+
initial_rank = (global_rank // pipe_size) * pipe_size
|
|
44
|
+
parameter_redundancy_dict = get_parameter_redundancy(
|
|
45
|
+
parameter_layout_dict, initial_rank)
|
|
46
|
+
value_len = sys.maxsize
|
|
47
|
+
min_value = ()
|
|
48
|
+
for key, value in parameter_redundancy_dict.items():
|
|
49
|
+
if "accu_grads" in key or "inputs" in key:
|
|
50
|
+
continue
|
|
51
|
+
for item in value:
|
|
52
|
+
if len(item) < value_len and global_rank in item:
|
|
53
|
+
value_len = len(item)
|
|
54
|
+
min_value = item
|
|
55
|
+
return min_value
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _get_ckpt_dir(append_dict, ckpt_save_path, is_tmp_file):
|
|
59
|
+
""" Common func to generate ckpt dir name."""
|
|
60
|
+
tmp = "_tmp" if is_tmp_file else ""
|
|
61
|
+
mid_dir = f"ttp_saved_checkpoints-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}{tmp}"
|
|
62
|
+
return os.path.join(ckpt_save_path, mid_dir)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _flush_from_cache(cb_params):
|
|
66
|
+
""" Flush cache data to host if tensor is cache enable."""
|
|
67
|
+
params = cb_params.train_network.get_parameters()
|
|
68
|
+
for param in params:
|
|
69
|
+
if param.cache_enable:
|
|
70
|
+
Tensor(param).flush_from_cache()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _save_checkpoint_on_failure(save_rank, step, rank_list, save_args):
|
|
74
|
+
""" Callback used for TTP save ckpt function when errors occur."""
|
|
75
|
+
logger.info("Enter _save_checkpoint_on_failure function")
|
|
76
|
+
ckpt_save_path, save_params, append_dict = save_args
|
|
77
|
+
ckpt_file = f"iteration-{str(append_dict['cur_epoch_num'])}_{str(append_dict['cur_step_num'])}.ckpt"
|
|
78
|
+
cur_ckpt_dir = _get_ckpt_dir(
|
|
79
|
+
append_dict, ckpt_save_path, True) + "/rank_" + str(save_rank)
|
|
80
|
+
os.makedirs(cur_ckpt_dir)
|
|
81
|
+
cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
|
|
82
|
+
save_checkpoint(save_params, cur_file,
|
|
83
|
+
integrated_save=False, append_dict=append_dict)
|
|
84
|
+
logger.info("Finish _save_checkpoint_on_failure function")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _convert_net_to_param_list(save_obj):
|
|
88
|
+
"""Convert nn.Cell to param_list."""
|
|
89
|
+
sync_pipeline_shared_parameters(save_obj)
|
|
90
|
+
param_list = []
|
|
91
|
+
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
92
|
+
if _is_in_auto_parallel_mode() and not parameter_layout_dict:
|
|
93
|
+
parameter_layout_dict = _get_parameter_layout()
|
|
94
|
+
if not _is_in_auto_parallel_mode():
|
|
95
|
+
save_obj.init_parameters_data()
|
|
96
|
+
param_dict = _convert_cell_param_and_names_to_dict(save_obj, None)
|
|
97
|
+
for (key, value) in param_dict.items():
|
|
98
|
+
each_param = {"name": key}
|
|
99
|
+
param_data = Tensor(value.asnumpy())
|
|
100
|
+
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
101
|
+
# which should be combined before saving
|
|
102
|
+
if key in parameter_layout_dict:
|
|
103
|
+
param_data = _get_merged_param_data(
|
|
104
|
+
save_obj, parameter_layout_dict, key, param_data, False)
|
|
105
|
+
each_param["data"] = param_data
|
|
106
|
+
param_list.append(each_param)
|
|
107
|
+
return param_list
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _rename_save_result(rename_args):
|
|
111
|
+
""" Callback used for TTP rename function after ckpt save callback was finished and successful."""
|
|
112
|
+
logger.info("Enter _rename_save_result function")
|
|
113
|
+
ckpt_save_path, _, append_dict = rename_args
|
|
114
|
+
|
|
115
|
+
tmp_dir = _get_ckpt_dir(append_dict, ckpt_save_path, True)
|
|
116
|
+
fin_dir = _get_ckpt_dir(append_dict, ckpt_save_path, False)
|
|
117
|
+
|
|
118
|
+
os.rename(tmp_dir, fin_dir)
|
|
119
|
+
logger.info("Finish _rename_save_result function")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class MindIOTTPAdapter(Callback):
|
|
123
|
+
"""
|
|
124
|
+
This callback is used to enable the feature
|
|
125
|
+
`MindIO TTP <https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindiottp/mindiottp001.html>`_.
|
|
126
|
+
This callback will execute TTP operations during training process, such as TTP init, report and exception handle.
|
|
127
|
+
|
|
128
|
+
Note:
|
|
129
|
+
Required for Ascend GE LazyInline mode only. And pipline size must be greater than 1.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
controller_ip (str): TTP controller's ip address, used for init TTP controller.
|
|
133
|
+
controller_port (int): TTP controller's ip port, used for init TTP controller and processor.
|
|
134
|
+
ckpt_save_path (str): Checkpoint save directory when failure occurs, checkpoint file will save to directory
|
|
135
|
+
named ttp_saved_checkpoints-{cur_epoch_num}_{cur_step_num} under this directory.
|
|
136
|
+
|
|
137
|
+
Raises:
|
|
138
|
+
Exception: TTP init failed.
|
|
139
|
+
ModuleNotFoundError: Mindio TTP whl package is not installed.
|
|
140
|
+
|
|
141
|
+
Examples:
|
|
142
|
+
>>> import numpy as np
|
|
143
|
+
>>> import os
|
|
144
|
+
>>> import math
|
|
145
|
+
>>> import mindspore as ms
|
|
146
|
+
>>> import mindspore.dataset as ds
|
|
147
|
+
>>> from mindspore import nn, ops, Parameter, train
|
|
148
|
+
>>> from mindspore.communication import init
|
|
149
|
+
>>> from mindspore.common.initializer import initializer, HeUniform
|
|
150
|
+
>>> from mindspore.train import Model, MindIOTTPAdapter
|
|
151
|
+
>>> from mindspore import dataset as ds
|
|
152
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE, jit_level='O2')
|
|
153
|
+
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=2)
|
|
154
|
+
>>> init()
|
|
155
|
+
>>> ms.set_seed(1)
|
|
156
|
+
>>> ms.set_auto_parallel_context(strategy_ckpt_config={"save_file":
|
|
157
|
+
>>> "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
|
|
158
|
+
>>> class MatMulCell(nn.Cell):
|
|
159
|
+
... def __init__(self, param=None, shape=None):
|
|
160
|
+
... super().__init__()
|
|
161
|
+
... if shape is None:
|
|
162
|
+
... shape = [28 * 28, 512]
|
|
163
|
+
... weight_init = HeUniform(math.sqrt(5))
|
|
164
|
+
... self.param = Parameter(initializer(weight_init, shape), name="param")
|
|
165
|
+
... if param is not None:
|
|
166
|
+
... self.param = param
|
|
167
|
+
... self.print = ops.Print()
|
|
168
|
+
... self.matmul = ops.MatMul()
|
|
169
|
+
...
|
|
170
|
+
... def construct(self, x):
|
|
171
|
+
... out = self.matmul(x, self.param)
|
|
172
|
+
... self.print("out is:", out)
|
|
173
|
+
... return out
|
|
174
|
+
>>>
|
|
175
|
+
>>> class Network(nn.Cell):
|
|
176
|
+
... def __init__(self):
|
|
177
|
+
... super().__init__()
|
|
178
|
+
... self.flatten = nn.Flatten()
|
|
179
|
+
... self.layer1 = MatMulCell()
|
|
180
|
+
... self.relu1 = nn.ReLU()
|
|
181
|
+
... self.layer2 = nn.Dense(512, 512)
|
|
182
|
+
... self.relu2 = nn.ReLU()
|
|
183
|
+
... self.layer3 = nn.Dense(512, 10)
|
|
184
|
+
...
|
|
185
|
+
... def construct(self, x):
|
|
186
|
+
... x = self.flatten(x)
|
|
187
|
+
... x = self.layer1(x)
|
|
188
|
+
... x = self.relu1(x)
|
|
189
|
+
... x = self.layer2(x)
|
|
190
|
+
... x = self.relu2(x)
|
|
191
|
+
... logits = self.layer3(x)
|
|
192
|
+
... return logits
|
|
193
|
+
>>>
|
|
194
|
+
>>> net = Network()
|
|
195
|
+
>>> net.layer1.pipeline_stage = 0
|
|
196
|
+
>>> net.relu1.pipeline_stage = 0
|
|
197
|
+
>>> net.layer2.pipeline_stage = 0
|
|
198
|
+
>>> net.relu2.pipeline_stage = 1
|
|
199
|
+
>>> net.layer3.pipeline_stage = 1
|
|
200
|
+
>>>
|
|
201
|
+
>>> def create_dataset(batch_size):
|
|
202
|
+
... dataset_path = os.getenv("DATA_PATH")
|
|
203
|
+
... dataset = ds.MnistDataset(dataset_path)
|
|
204
|
+
... image_transforms = [
|
|
205
|
+
... ds.vision.Rescale(1.0 / 255.0, 0),
|
|
206
|
+
... ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
|
|
207
|
+
... ds.vision.HWC2CHW()
|
|
208
|
+
... ]
|
|
209
|
+
... label_transform = ds.transforms.TypeCast(ms.int32)
|
|
210
|
+
... dataset = dataset.map(image_transforms, 'image')
|
|
211
|
+
... dataset = dataset.map(label_transform, 'label')
|
|
212
|
+
... dataset = dataset.batch(batch_size)
|
|
213
|
+
... return dataset
|
|
214
|
+
>>>
|
|
215
|
+
>>> data_set = create_dataset(32)
|
|
216
|
+
>>>
|
|
217
|
+
>>> optimizer = nn.SGD(net.trainable_params(), 1e-2)
|
|
218
|
+
>>> loss_fn = nn.CrossEntropyLoss()
|
|
219
|
+
>>>
|
|
220
|
+
>>> net_with_loss = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
|
|
221
|
+
>>> net_with_loss.set_train()
|
|
222
|
+
>>> model = Model(net_with_loss, optimizer=optimizer)
|
|
223
|
+
>>> ttp_cb = MindIOTTPAdapter("192.168.0.1", 2000, "./ttp_checkpoint/")
|
|
224
|
+
>>> loss_cb = train.LossMonitor(1)
|
|
225
|
+
>>> model.train(1, dataset, callbacks=[ttp_cb, loss_cb])
|
|
226
|
+
"""
|
|
227
|
+
|
|
228
|
+
def __init__(self, controller_ip, controller_port, ckpt_save_path):
|
|
229
|
+
super(MindIOTTPAdapter, self).__init__()
|
|
230
|
+
# let it raises errors if not install mindio_ttp package
|
|
231
|
+
from mindio_ttp import framework_ttp as ttp
|
|
232
|
+
self.ttp = ttp
|
|
233
|
+
Validator.check_non_negative_int(controller_port)
|
|
234
|
+
self.has_init = False
|
|
235
|
+
self.enable = False
|
|
236
|
+
mode = context.get_context("mode")
|
|
237
|
+
if context.get_context("device_target") != "Ascend" or mode != context.GRAPH_MODE:
|
|
238
|
+
logger.warning(
|
|
239
|
+
"MindIO adataper only support on Ascend device with GRAPH Mode.")
|
|
240
|
+
return
|
|
241
|
+
if os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT") != "true":
|
|
242
|
+
logger.warning("MindIO adataper need custom switch on.")
|
|
243
|
+
return
|
|
244
|
+
ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
|
|
245
|
+
if ttp_lib_path is None or os.path.isfile(ttp_lib_path) is False:
|
|
246
|
+
logger.warning(
|
|
247
|
+
"MindIO adataper switch on, but ttp library path is not correct.")
|
|
248
|
+
return
|
|
249
|
+
self.enable = True
|
|
250
|
+
self._controller_ip = controller_ip
|
|
251
|
+
self._controller_port = controller_port
|
|
252
|
+
self._ckpt_save_path = ckpt_save_path
|
|
253
|
+
|
|
254
|
+
def wrapper_ttp_persist(self, func):
|
|
255
|
+
"""
|
|
256
|
+
This method is used to wrapper TTP exception handler for the input func.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
func (function): train method that need to be wrapper.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
Function, if the TTP is enabled, return the encapsulated function,
|
|
263
|
+
otherwise the original function is returned.
|
|
264
|
+
|
|
265
|
+
"""
|
|
266
|
+
if self.enable:
|
|
267
|
+
return self.ttp.ttp_to_persist(func)
|
|
268
|
+
return func
|
|
269
|
+
|
|
270
|
+
def _init_ttp(self, run_context):
|
|
271
|
+
""" Init Mindio TTP, used internal. """
|
|
272
|
+
logger.info("Begin to init ttp.")
|
|
273
|
+
dev_num = _get_device_num()
|
|
274
|
+
|
|
275
|
+
cb_params = run_context.original_args()
|
|
276
|
+
param_layout_dict = cb_params.train_network.parameter_layout_dict
|
|
277
|
+
dp = _get_dp_from_layout(param_layout_dict)
|
|
278
|
+
logger.info("Init TTP with dp: {}.".format(dp))
|
|
279
|
+
|
|
280
|
+
self.ttp.ttp_register_save_ckpt_handler(_save_checkpoint_on_failure)
|
|
281
|
+
self.ttp.ttp_register_rename_handler(_rename_save_result)
|
|
282
|
+
|
|
283
|
+
world_size = dev_num
|
|
284
|
+
cur_rank = get_rank()
|
|
285
|
+
is_odd = len(dp) % 2
|
|
286
|
+
replica = 2 if is_odd else len(dp) // 2
|
|
287
|
+
enable_local_copy = False
|
|
288
|
+
if cur_rank == 0:
|
|
289
|
+
logger.info("Begin to start ttp controller.")
|
|
290
|
+
self.ttp.ttp_init_controller(
|
|
291
|
+
cur_rank, world_size, replica, enable_local_copy)
|
|
292
|
+
self.ttp.ttp_start_controller(
|
|
293
|
+
self._controller_ip, self._controller_port)
|
|
294
|
+
logger.info("Finish start ttp controller.")
|
|
295
|
+
|
|
296
|
+
logger.info("Begin to start ttp processor.")
|
|
297
|
+
self.ttp.ttp_init_processor(cur_rank, dp, len(
|
|
298
|
+
dp), world_size, replica, enable_local_copy)
|
|
299
|
+
self.ttp.ttp_start_processor(
|
|
300
|
+
self._controller_ip, self._controller_port)
|
|
301
|
+
logger.info("Finished start ttp processor.")
|
|
302
|
+
|
|
303
|
+
logger.info("Finish init ttp.")
|
|
304
|
+
|
|
305
|
+
def on_train_step_end(self, run_context):
|
|
306
|
+
"""
|
|
307
|
+
Init TTP Controller only once after first step finished.
|
|
308
|
+
And report status to MindIO TTP after every step finished.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
run_context (RunContext): Context of the train running. Refer to
|
|
312
|
+
:class:`mindspore.train.RunContext` for detail.
|
|
313
|
+
|
|
314
|
+
"""
|
|
315
|
+
|
|
316
|
+
if self.enable is False:
|
|
317
|
+
return
|
|
318
|
+
pp_num = _get_auto_parallel_context("pipeline_stages")
|
|
319
|
+
if pp_num < 2:
|
|
320
|
+
self.enable = False
|
|
321
|
+
return
|
|
322
|
+
cb_params = run_context.original_args()
|
|
323
|
+
if cb_params.dataset_sink_mode is True and cb_params.sink_size > 1:
|
|
324
|
+
self.enable = False
|
|
325
|
+
return
|
|
326
|
+
if self.has_init is False:
|
|
327
|
+
self.has_init = True
|
|
328
|
+
self._init_ttp(run_context)
|
|
329
|
+
_flush_from_cache(cb_params)
|
|
330
|
+
cur_rank = get_rank()
|
|
331
|
+
append_dict = {}
|
|
332
|
+
append_dict["cur_epoch_num"] = cb_params.cur_epoch_num
|
|
333
|
+
append_dict["cur_step_num"] = int(
|
|
334
|
+
(cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
|
335
|
+
append_dict["cur_rank"] = cur_rank
|
|
336
|
+
append_dict["batch_num"] = cb_params.batch_num
|
|
337
|
+
append_dict["global_step"] = cb_params.cur_step_num
|
|
338
|
+
|
|
339
|
+
save_params = _convert_net_to_param_list(cb_params.train_network)
|
|
340
|
+
save_params_copy = copy.deepcopy(save_params)
|
|
341
|
+
|
|
342
|
+
logger.info("Set ckpt args to TTP.")
|
|
343
|
+
self.ttp.ttp_set_ckpt_args(
|
|
344
|
+
(self._ckpt_save_path, save_params_copy, append_dict))
|
|
345
|
+
logger.info("Set optimizer finish step status to TTP.")
|
|
346
|
+
self.ttp.ttp_end_updating_os(cb_params.cur_step_num)
|
|
347
|
+
|
|
348
|
+
@staticmethod
|
|
349
|
+
def load_checkpoint_with_backup(ckpt_file_path, strategy_file_path, net):
|
|
350
|
+
"""
|
|
351
|
+
Load checkpoint into network, and use strategy file to find backup checkpoint file
|
|
352
|
+
when origin checkpoint file not found.
|
|
353
|
+
|
|
354
|
+
Note:
|
|
355
|
+
This API must be called after the communication is initialized because the cluster information
|
|
356
|
+
needs to be obtained internally.
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
ckpt_file_path (str): the checkpoint file to be loaded.
|
|
360
|
+
strategy_file_path (str): strategy file path for current rank.
|
|
361
|
+
net (Cell): network that needs to load checkpoint.
|
|
362
|
+
|
|
363
|
+
Returns:
|
|
364
|
+
Dict, checkpoint weights after loaded.
|
|
365
|
+
|
|
366
|
+
Raises:
|
|
367
|
+
ValueError: Failed to load the checkpoint file.
|
|
368
|
+
|
|
369
|
+
Examples:
|
|
370
|
+
>>> import numpy as np
|
|
371
|
+
>>> from mindspore import nn
|
|
372
|
+
>>> from mindspore.train import Model, MindIOTTPAdapter
|
|
373
|
+
>>> from mindspore import dataset as ds
|
|
374
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
375
|
+
>>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
|
|
376
|
+
>>> init()
|
|
377
|
+
>>> ms.set_seed(1)
|
|
378
|
+
>>> class Network(nn.Cell):
|
|
379
|
+
... def __init__(self):
|
|
380
|
+
... super().__init__()
|
|
381
|
+
... self.flatten = nn.Flatten()
|
|
382
|
+
... self.fc = nn.Dense(28*28, 10, weight_init="normal", bias_init="zeros")
|
|
383
|
+
... self.relu = nn.ReLU()
|
|
384
|
+
...
|
|
385
|
+
... def construct(self, x):
|
|
386
|
+
... x = self.flatten(x)
|
|
387
|
+
... logits = self.relu(self.fc(x))
|
|
388
|
+
... return logits
|
|
389
|
+
>>>
|
|
390
|
+
>>> net = Network()
|
|
391
|
+
>>>
|
|
392
|
+
>>> def create_dataset(batch_size):
|
|
393
|
+
... dataset_path = os.getenv("DATA_PATH")
|
|
394
|
+
... rank_id = get_rank()
|
|
395
|
+
... rank_size = get_group_size()
|
|
396
|
+
... dataset = ds.MnistDataset(dataset_path, num_shards=rank_size, shard_id=rank_id)
|
|
397
|
+
... image_transforms = [
|
|
398
|
+
... ds.vision.Rescale(1.0 / 255.0, 0),
|
|
399
|
+
... ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
|
|
400
|
+
... ds.vision.HWC2CHW()
|
|
401
|
+
... ]
|
|
402
|
+
... label_transform = ds.transforms.TypeCast(ms.int32)
|
|
403
|
+
... dataset = dataset.map(image_transforms, 'image')
|
|
404
|
+
... dataset = dataset.map(label_transform, 'label')
|
|
405
|
+
... dataset = dataset.batch(batch_size)
|
|
406
|
+
... return dataset
|
|
407
|
+
>>> data_set = create_dataset(32)
|
|
408
|
+
>>> ckpt_file= "./rank_5/iteration-1_40.ckpt"
|
|
409
|
+
>>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
|
|
410
|
+
>>> param_dict = MindIOTTPAdapter.load_checkpoint_with_backup(ckpt_file, stragegy_file, net)
|
|
411
|
+
>>> data_set.set_init_step(param_dict["global_step"])
|
|
412
|
+
"""
|
|
413
|
+
logger.info("Start load checkpoint with strategy file.")
|
|
414
|
+
try:
|
|
415
|
+
param_dict = ms.load_checkpoint(ckpt_file_path)
|
|
416
|
+
except ValueError as e:
|
|
417
|
+
logger.warning(
|
|
418
|
+
"Loading origin checkpoint file failed, the reason is:{}.".format(str(e)))
|
|
419
|
+
dp = _get_dp_from_layout(strategy_file_path)
|
|
420
|
+
rank = get_rank()
|
|
421
|
+
logger.info(
|
|
422
|
+
"Can't load origin checkpoint file, found dp:{}.".format(dp))
|
|
423
|
+
for i in dp:
|
|
424
|
+
if i == rank:
|
|
425
|
+
continue
|
|
426
|
+
new_ckpt = ckpt_file_path.replace(
|
|
427
|
+
f"/rank_{rank}/", f"/rank_{str(i)}/")
|
|
428
|
+
if not os.path.isfile(new_ckpt):
|
|
429
|
+
continue
|
|
430
|
+
try:
|
|
431
|
+
param_dict = ms.load_checkpoint(new_ckpt)
|
|
432
|
+
except ValueError as e1:
|
|
433
|
+
logger.warning(
|
|
434
|
+
"Loading strategy checkpoint file failed, the reason is:{}.".format(str(e1)))
|
|
435
|
+
param_dict = None
|
|
436
|
+
if param_dict:
|
|
437
|
+
logger.info("Found param dict, load it into network.")
|
|
438
|
+
ms.load_param_into_net(net, param_dict)
|
|
439
|
+
else:
|
|
440
|
+
raise ValueError(
|
|
441
|
+
"Load checkpoint file failed, please check your config is set correctly.")
|
|
442
|
+
logger.info("Finish load checkpoint with strategy file.")
|
|
443
|
+
return param_dict
|
|
@@ -55,13 +55,13 @@ class OnRequestExit(Callback):
|
|
|
55
55
|
>>> import mindspore as ms
|
|
56
56
|
>>>
|
|
57
57
|
>>> # Define the network structure of LeNet5. Refer to
|
|
58
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
58
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
59
59
|
>>> net = LeNet5()
|
|
60
60
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
61
61
|
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
|
|
62
62
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
63
63
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
64
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
64
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
65
65
|
>>> dataset = create_dataset()
|
|
66
66
|
>>> on_request_exit = ms.train.OnRequestExit(file_name='LeNet5')
|
|
67
67
|
>>> model.train(10, dataset, callbacks=on_request_exit)
|
|
@@ -84,13 +84,13 @@ class ReduceLROnPlateau(Callback):
|
|
|
84
84
|
>>> from mindspore import nn
|
|
85
85
|
>>> from mindspore.train import Model, ReduceLROnPlateau
|
|
86
86
|
>>> # Define the network structure of LeNet5. Refer to
|
|
87
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
87
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
88
88
|
>>> net = LeNet5()
|
|
89
89
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
90
90
|
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
|
|
91
91
|
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"acc"})
|
|
92
92
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
93
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
93
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
94
94
|
>>> dataset = create_dataset()
|
|
95
95
|
>>> cb = ReduceLROnPlateau(monitor="acc", patience=3, verbose=True)
|
|
96
96
|
>>> model.fit(10, dataset, callbacks=cb)
|
|
@@ -105,11 +105,11 @@ class SummaryCollector(Callback):
|
|
|
105
105
|
training computational graph is collected. Default: ``True`` .
|
|
106
106
|
- collect_train_lineage (bool): Whether to collect lineage data for the training phase,
|
|
107
107
|
this field will be displayed on the `lineage page \
|
|
108
|
-
<https://www.mindspore.cn/mindinsight/docs/en/
|
|
108
|
+
<https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_
|
|
109
109
|
of MindInsight. Default: ``True`` .
|
|
110
110
|
- collect_eval_lineage (bool): Whether to collect lineage data for the evaluation phase,
|
|
111
111
|
this field will be displayed on the `lineage page
|
|
112
|
-
<https://www.mindspore.cn/mindinsight/docs/en/
|
|
112
|
+
<https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_
|
|
113
113
|
of MindInsight. Default: ``True`` .
|
|
114
114
|
- collect_input_data (bool): Whether to collect dataset for each training.
|
|
115
115
|
Currently only image data is supported.
|
|
@@ -134,7 +134,7 @@ class SummaryCollector(Callback):
|
|
|
134
134
|
Optional: epoch/step.
|
|
135
135
|
- create_landscape (dict): Select how to create loss landscape.
|
|
136
136
|
Training process loss landscape(train) and training result loss landscape(result).
|
|
137
|
-
Default: {"train": True, "result": True}
|
|
137
|
+
Default: ``{"train": True, "result": True}``. Optional: ``True`` / ``False`` .
|
|
138
138
|
- num_samples (int): The size of the dataset used to create the loss landscape.
|
|
139
139
|
For example, in image dataset, You can set num_samples is 128,
|
|
140
140
|
which means that 128 images are used to create loss landscape.
|
|
@@ -150,7 +150,7 @@ class SummaryCollector(Callback):
|
|
|
150
150
|
False: it means that after specified data is set, only the specified data is collected,
|
|
151
151
|
and the others are not collected. Default: ``True`` .
|
|
152
152
|
custom_lineage_data (Union[dict, None]): Allows you to customize the data and present it on the MingInsight
|
|
153
|
-
`lineage page <https://www.mindspore.cn/mindinsight/docs/en/
|
|
153
|
+
`lineage page <https://www.mindspore.cn/mindinsight/docs/en/master/lineage_and_scalars_comparison.html>`_ .
|
|
154
154
|
In the custom data, the type of the key supports str, and the type of value supports str, int
|
|
155
155
|
and float. Default: ``None`` , it means there is no custom data.
|
|
156
156
|
collect_tensor_freq (Optional[int]): The same semantics as the `collect_freq`, but controls TensorSummary only.
|
|
@@ -190,10 +190,10 @@ class SummaryCollector(Callback):
|
|
|
190
190
|
... ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
|
|
191
191
|
... mnist_dataset_dir = '/path/to/mnist_dataset_directory'
|
|
192
192
|
... # Create the dataset taking MNIST as an example. Refer to
|
|
193
|
-
... # https://gitee.com/mindspore/docs/blob/
|
|
193
|
+
... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
194
194
|
... ds_train = create_dataset()
|
|
195
195
|
... # Define the network structure of LeNet5. Refer to
|
|
196
|
-
... # https://gitee.com/mindspore/docs/blob/
|
|
196
|
+
... # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
197
197
|
... network = LeNet5(10)
|
|
198
198
|
... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
199
199
|
... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
|
@@ -1156,7 +1156,7 @@ class SummaryCollector(Callback):
|
|
|
1156
1156
|
except TypeError as exc:
|
|
1157
1157
|
logger.warning("Summary cannot collect the type of metrics, currently support type: dict, list, tuple, "
|
|
1158
1158
|
"str, int, float, bool and None. %s.", str(exc))
|
|
1159
|
-
|
|
1159
|
+
self._parse_dataset(cb_params, eval_lineage)
|
|
1160
1160
|
|
|
1161
1161
|
eval_lineage_message = self._package_eval_lineage_message(eval_lineage)
|
|
1162
1162
|
self._record.add_value(PluginEnum.EVAL_LINEAGE.value, 'eval_lineage', eval_lineage_message)
|
|
@@ -30,7 +30,7 @@ class TimeMonitor(Callback):
|
|
|
30
30
|
if the program get `batch_num` during training, `data_size` will be set to `batch_num`,
|
|
31
31
|
otherwise `data_size` will be used. Default: ``None`` .
|
|
32
32
|
|
|
33
|
-
data_time (bool): Whether to
|
|
33
|
+
data_time (bool): Whether to show the average time of fetching data in Host.
|
|
34
34
|
Note that data fetch and network compute are processed sequentially in non dataset sink mode, while
|
|
35
35
|
they are asynchronous in dataset sink mode. Default: ``False`` .
|
|
36
36
|
|
|
@@ -43,13 +43,13 @@ class TimeMonitor(Callback):
|
|
|
43
43
|
>>> from mindspore.train import Model, TimeMonitor
|
|
44
44
|
>>>
|
|
45
45
|
>>> # Define the network structure of LeNet5. Refer to
|
|
46
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
46
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
47
47
|
>>> net = LeNet5()
|
|
48
48
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
49
49
|
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
|
|
50
50
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
51
51
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
52
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
52
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
53
53
|
>>> dataset = create_dataset()
|
|
54
54
|
>>> time_monitor = TimeMonitor()
|
|
55
55
|
>>> model.train(10, dataset, callbacks=time_monitor)
|
mindspore/train/data_sink.py
CHANGED
|
@@ -130,10 +130,11 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
|
|
|
130
130
|
A wrapper function to generate a function for the input function.
|
|
131
131
|
|
|
132
132
|
Note:
|
|
133
|
-
When using data sinking, the dataset will be automatically
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
133
|
+
When using data sinking, the dataset will be automatically looped to the device. The device side can cache up
|
|
134
|
+
to 100 batches of data and occupy no more than 2GB of memory. At this time, only the number of steps for each
|
|
135
|
+
sinking `sink_size` needs to be considered. `sink_size` defaults to ``1``, indicating that each epoch only
|
|
136
|
+
takes one batch of data from the cache for training and outputs a loss. If `sink_size` is greater than 1, each
|
|
137
|
+
epoch takes out `sink_size` batches of data from the cache for training and outputs a loss.
|
|
137
138
|
|
|
138
139
|
Args:
|
|
139
140
|
fn (Function): The Python function that will be run with dataset.
|
|
@@ -145,7 +146,7 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
|
|
|
145
146
|
input_signature (Union[Tensor, List or Tuple of Tensors]): The Tensor which describes the input arguments.
|
|
146
147
|
The shape and dtype of the Tensor will be supplied to this function. If input_signature is specified,
|
|
147
148
|
each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept `**kwargs`. The shape
|
|
148
|
-
and dtype of actual inputs should keep the same as input_signature
|
|
149
|
+
and dtype of actual inputs should keep the same as `input_signature`. Otherwise, TypeError will be raised.
|
|
149
150
|
Default: ``None`` .
|
|
150
151
|
|
|
151
152
|
Returns:
|