mindspore-2.2.14-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.3.0rc2-cp38-cp38-manylinux1_x86_64.whl
This diff shows the changes between publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +18 -8
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +78 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +127 -206
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _profiler.py} +13 -16
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +7 -3
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +145 -50
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +9 -6
- mindspore/common/dump.py +5 -4
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +33 -13
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +24 -4
- mindspore/common/recompute.py +247 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +391 -465
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +176 -55
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +12 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6318 -1760
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +131 -174
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +113 -81
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +14 -18
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +101 -57
- mindspore/ops/_vmap/vmap_nn_ops.py +230 -97
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +205 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +257 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +171 -0
- mindspore/ops/auto_generate/gen_extend_func.py +404 -0
- mindspore/ops/auto_generate/gen_ops_def.py +5653 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +11623 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +359 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +168 -602
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +24 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +54 -0
- mindspore/ops/extend/array_func.py +259 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +384 -0
- mindspore/ops/function/__init__.py +37 -12
- mindspore/ops/function/array_func.py +702 -1867
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +639 -2531
- mindspore/ops/function/nn_func.py +1274 -832
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +84 -71
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +57 -63
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +19 -20
- mindspore/ops/operations/_grad_ops.py +20 -828
- mindspore/ops/operations/_inner_ops.py +180 -288
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/array_ops.py +83 -2697
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +14 -96
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1716 -0
- mindspore/ops/operations/math_ops.py +581 -4629
- mindspore/ops/operations/nn_ops.py +260 -1941
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +257 -0
- mindspore/ops_generate/arg_handler.py +171 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +939 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +349 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +6 -4
- mindspore/parallel/_auto_parallel_context.py +52 -2
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +29 -13
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +147 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -24
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +109 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +294 -133
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +305 -167
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +18 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +257 -133
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +534 -1066
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
@@ -50,9 +50,11 @@ from mindspore.common.api import _generate_branch_control_input
 from mindspore.common.initializer import initializer, One
 from mindspore.common.parameter import Parameter, _offload_if_config
 from mindspore.common.tensor import Tensor
+from mindspore._c_expression import Tensor as Tensor_
 from mindspore.common._utils import is_shape_unknown
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.experimental import MapParameter
+from mindspore.ops import Cast
 from mindspore.parallel._cell_wrapper import get_allgather_cell
 from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
 from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
@@ -61,21 +63,20 @@ from mindspore.parallel._parallel_serialization import _convert_to_list, _conver
     _restore_group_info_list
 from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
     _store_warm_up_ptr_by_tensor_list, _cache_enable
+from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
 from mindspore.train._utils import read_proto
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
     split_mindir, split_dynamic_mindir
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
-from ..ops.operations import Cast
 
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
-                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
+                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
 
 tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
-                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"
-                     "BFloat16": np.float32}
+                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
 
 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
 
@@ -96,6 +97,17 @@ INT_64_MAX = 9223372036854775807
 cpu_cast = Cast().set_device("CPU")
 
 
+class ParamDictFuture:
+    def __init__(self, executor, param_dict_future):
+        self.executor = executor
+        self.param_dict_future = param_dict_future
+
+    def result(self):
+        param_dict = self.param_dict_future.result()
+        self.executor.shutdown()
+        return param_dict
+
+
 def _special_process_par(par, new_par):
     """
     Processes the special condition.
@@ -242,21 +254,21 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                 continue
             if value[0] == "offload_parameter":
                 new_value = value[1:]
-
-
-                else:
-                    new_value[2] = value[3].asnumpy().reshape(-1)
-                _write_parameter_data(name, new_value, f, enc_key, plain_data)
+                new_value[2] = value[3]
+                _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
                 _offload_if_config(value[3])
                 continue
-            if value[
-
+            if value[1] == "str":
+                _write_parameter_data(name, value, f, enc_key, plain_data)
                 continue
-            if isinstance(value[2],
+            if isinstance(value[2], np.ndarray):
+                _write_parameter_data(name, value, f, enc_key, plain_data)
+                continue
+            if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
                 _write_hugeparameter(name, value, f)
                 continue
 
-
+            _write_parameter_bytes_data(name, value, f, enc_key, plain_data)
 
         if enc_key is not None:
             plain_data.seek(0)
@@ -286,21 +298,6 @@ def _write_random_seed(name, value, f):
     f.write(checkpoint_list.SerializeToString())
 
 
-def _write_bfloat16_data(name, value, f, enc_key, plain_data):
-    """Write bfloat16 data into protobuf file"""
-    checkpoint_list = Checkpoint()
-    param_value = checkpoint_list.value.add()
-    param_value.tag = name
-    param_tensor = param_value.tensor
-    param_tensor.dims.extend(value[1])
-    param_tensor.tensor_type = value[2]
-    param_tensor.tensor_content = value[3].get_bytes()
-    if enc_key is None:
-        f.write(checkpoint_list.SerializeToString())
-    else:
-        plain_data.write(checkpoint_list.SerializeToString())
-
-
 def _write_parameter_data(name, value, f, enc_key, plain_data):
     """Write parameter data into protobuf file."""
     data_size = value[2].nbytes / 1024
@@ -325,6 +322,26 @@ def _write_parameter_data(name, value, f, enc_key, plain_data):
         plain_data.write(checkpoint_list.SerializeToString())
 
 
+def _write_parameter_bytes_data(name, value, f, enc_key, plain_data):
+    """Write parameter bytes data into protobuf file."""
+    bytes_value = value[2].get_bytes()
+    chunk_size = 1024 * SLICE_SIZE
+
+    for i in range(0, len(bytes_value), chunk_size):
+        checkpoint_list = Checkpoint()
+        param_value = checkpoint_list.value.add()
+        param_value.tag = name
+        param_tensor = param_value.tensor
+        param_tensor.dims.extend(value[0])
+        param_tensor.tensor_type = value[1]
+        param_tensor.tensor_content = bytes_value[i:i + chunk_size]
+
+        if enc_key is None:
+            f.write(checkpoint_list.SerializeToString())
+        else:
+            plain_data.write(checkpoint_list.SerializeToString())
+
+
 def _write_mapparameter(name, value, f, map_param_inc=False):
     """Write map parameter into protobuf file."""
     while True:
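Note on the hunk above: the new `_write_parameter_bytes_data` path takes the tensor's raw bytes via `get_bytes()` and emits one serialized `Checkpoint` message per `1024 * SLICE_SIZE` slice, instead of round-tripping through numpy. A minimal sketch of that chunking idea, where `SLICE_SIZE` and `write_chunk` are illustrative stand-ins rather than the module's real constant and protobuf writer:

    # Sketch only: chunked write of a large byte string, assuming a caller-supplied sink.
    SLICE_SIZE = 4  # stand-in; the real module uses a much larger constant

    def write_in_chunks(raw_bytes, write_chunk, slice_size=SLICE_SIZE):
        chunk_size = 1024 * slice_size
        for start in range(0, len(raw_bytes), chunk_size):
            # Each chunk becomes its own serialized record, so no single record grows unbounded.
            write_chunk(raw_bytes[start:start + chunk_size])

    chunks = []
    write_in_chunks(b"\x00" * 10000, chunks.append)
    print([len(c) for c in chunks])  # [4096, 4096, 1808]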
@@ -420,7 +437,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ms.save_checkpoint(net, "./lenet.ckpt",
         ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
@@ -440,7 +457,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
     integrated_save = Validator.check_bool(integrated_save)
@@ -479,10 +496,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         elif param["data"][0] == "offload_parameter":
             data_list[key].append("offload_parameter")
             _save_param_list_data(data_list, key, param)
-        elif param["data"][0] == "BFloat16_tensor":
-            data_list[key].append("BFloat16_tensor")
-            _save_param_list_data(data_list, key, param)
-            continue
 
         if isinstance(param["data"], str):
             data_list[key].append([0])
@@ -492,28 +505,13 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         else:
             if isinstance(param["data"], Parameter):
                 param["data"].init_data()
-            if isinstance(param["data"], Tensor) and param["data"].dtype == mstype.bfloat16:
-                data_list[key].append("BFloat16_tensor")
-                dims = []
-                for dim in param["data"].shape:
-                    dims.append(dim)
-                data_list[key].append(dims)
-                data_list[key].append("BFloat16")
-                data_list[key].append(cpu_cast(param["data"], mstype.float32))
-                continue
             dims = []
-
-            dims.append(
-            else:
-                for dim in param['data'].shape:
-                    dims.append(dim)
+            for dim in param['data'].shape:
+                dims.append(dim)
             data_list[key].append(dims)
             tensor_type = str(param["data"].dtype)
             data_list[key].append(tensor_type)
-
-            data = cpu_cast(param["data"], mstype.float32).asnumpy().reshape(-1)
-            else:
-                data = param["data"].asnumpy().reshape(-1)
+            data = param["data"]
             data_list[key].append(data)
 
     if async_save:
@@ -532,7 +530,21 @@ def _convert_list_to_param_list(save_obj, choice_func):
     if not save_obj:
         return param_list
     if isinstance(save_obj[0], dict):
-
+        for param in save_obj:
+            if isinstance(param, dict) and "name" in param and "data" in param:
+                if not isinstance(param["name"], str):
+                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
+                                    f"should be string, but got {type(param['name'])}.")
+                if not isinstance(param["data"], Tensor):
+                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
+                                    f"should be parameter, but got {type(param['data'])}.")
+                if choice_func is not None and not choice_func(param["name"]):
+                    continue
+                each_param = {"name": param["name"], "data": param["data"]}
+                param_list.append(each_param)
+            else:
+                raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
+                                f"have key values 'name' and 'value', but got {type(param)} and {param}.")
     else:
         for param in save_obj:
             if isinstance(param, Parameter):
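Note on the hunk above: with the tightened validation, a `save_obj` given as a list of dicts must carry a string `name` and a Tensor `data` in every item, and `choice_func` can still filter entries by name. A hedged usage sketch (parameter names and shapes are illustrative only):

    import numpy as np
    import mindspore as ms

    # Each dict entry needs a string "name" and a Tensor "data"; anything else raises TypeError.
    params = [
        {"name": "conv1.weight", "data": ms.Tensor(np.ones((3, 3), np.float32))},
        {"name": "fc.bias", "data": ms.Tensor(np.zeros((10,), np.float32))},
    ]
    # choice_func filters by name, exactly as for the nn.Cell form of save_obj.
    ms.save_checkpoint(params, "./list_of_dicts.ckpt",
                       choice_func=lambda name: not name.startswith("fc"))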
@@ -585,6 +597,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
 
 
 def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
     """Convert nn.Cell to param_list."""
+    sync_pipeline_shared_parameters(save_obj)
     param_list = []
     parameter_layout_dict = save_obj.parameter_layout_dict
     if _is_in_auto_parallel_mode() and not parameter_layout_dict:
@@ -597,7 +610,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
     if phase in save_obj.compile_cache and _executor.has_compiled(phase):
         random_byte = _executor._graph_executor.get_random_status(phase)
         param_list.append({"name": "random_op", "data": random_byte})
-
+        append_dict.pop("random_op")
     for (key, value) in param_dict.items():
         each_param = {"name": key}
         if isinstance(value, MapParameter):
@@ -619,18 +632,13 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
             param_data.append(param_tensor.shape)
             param_data.append(str(param_tensor.dtype))
             param_data.append(value.key)
-        elif value.data.dtype == mstype.bfloat16:
-            param_data = ["BFloat16_tensor"]
-            param_data.append(cpu_cast(value.data, mstype.float32))
-            param_data.append(value.data.shape)
-            param_data.append("BFloat16")
-            param_data.append(value.key)
         else:
-            param_data =
+            param_data = value.data
 
             # in automatic model parallel scenario, some parameters were split to all the devices,
             # which should be combined before saving
             if key in parameter_layout_dict:
+                param_data = Tensor(value.data)
                 param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                     integrated_save)
 
@@ -699,13 +707,13 @@ def load(file_name, **kwargs):
         - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
         - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
 
-          - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'
+          - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
           - For details of using the customized decryption, please check the `tutorial
-            <https://mindspore.cn/mindarmour/docs/en/
+            <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
 
         - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
           `obfuscate_model()
-          <https://www.mindspore.cn/docs/en/
+          <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
 
     Returns:
         GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -735,7 +743,7 @@ def load(file_name, **kwargs):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     if not isinstance(file_name, str):
         raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -776,7 +784,7 @@ def load(file_name, **kwargs):
     return graph
 
 
-def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=
+def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
     """
     Auto Split MindIR.
 
@@ -784,10 +792,10 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
 
     Args:
         file_name (str): MindIR file name.
-        device_num (int): device number.
-        rank_id (int): rank id.
-        dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
-        sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
+        device_num (int): device number. Default: '8'.
+        rank_id (int): rank id. Default: '0'.
+        dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'.
+        sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'.
 
     Raises:
         ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -909,13 +917,14 @@ def obfuscate_model(obf_config, **kwargs):
         - customized_func (function): A python function used for customized function mode, which used for control
           the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
           Reference to 'my_func()' in
-          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/
+          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
           This function needs to ensure that its result is constant for any input. Users can refer to opaque
           predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
           when loading obfuscated model.
         - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
           structure of obfuscated models corresponding to different random seeds is different. If
-          `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell
+          `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
+          interface when loading
           obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
           be set, and the latter mode would be applied if both of them are set.
 
@@ -923,7 +932,7 @@ def obfuscate_model(obf_config, **kwargs):
 
         - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
         - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
-
+          Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
 
     Raises:
         TypeError: If `obf_config` is not a dict.
@@ -934,11 +943,15 @@ def obfuscate_model(obf_config, **kwargs):
        ValueError: If `obf_ratio` is not provided in `obf_config`.
        ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
        ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
-       ValueError: If `original_model_path`
+       ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
 
     Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
+       >>> import numpy as np
+       >>> # Download ori_net.mindir
+       >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
+       >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
       >>> obf_config = {'original_model_path': "./net.mindir",
       ...               'save_model_path': "./obf_net",
       ...               'model_inputs': [input1, ],
@@ -1076,7 +1089,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
     specify_prefix = _check_prefix(specify_prefix)
@@ -1119,31 +1132,20 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         if data_type == 'str':
             str_length = int(len(data) / 4)
             np_type = np_type + str(str_length)
-
-            dims = element.tensor.dims
-            param_data = np.frombuffer(data, np_type)
-            param_data = param_data.reshape(list(dims))
-            parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-            parameter_dict[element.tag] = parameter
-            continue
-        element_data = np.frombuffer(data, np_type)
-        param_data_list.append(element_data)
+        param_data_list.append(data)
         if (element_id == len(checkpoint_list.value) - 1) or \
                 (element.tag != checkpoint_list.value[element_id + 1].tag):
             new_data = b"".join(param_data_list)
-            param_data = np.frombuffer(new_data, np_type)
             param_data_list.clear()
             dims = element.tensor.dims
             if dims == [0] and data_type == 'str':
-
+                str_value = np.frombuffer(new_data, np_type)
+                parameter_dict[element.tag] = str(str_value[0])
             else:
-                if dims == [0]
-
-
-
-                if dims not in ([0], [1]):
-                    param_data = param_data.reshape(list(dims))
-                parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
+                if dims == [0]:
+                    dims = []
+                param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
+                parameter = Parameter(param_data, name=element.tag)
                 parameter_dict[element.tag] = parameter
                 _offload_if_config(parameter)
 
@@ -1168,6 +1170,86 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     return parameter_dict
 
 
+def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
+                          dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
+    """
+    Load checkpoint info from a specified file asyncly.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Note:
+        - `specify_prefix` and `filter_prefix` do not affect each other.
+        - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
+        - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
+          `choice_func` is recommended instead.
+          And using either of those two args will override `choice_func` at the same time.
+
+    Args:
+        ckpt_file_name (str): Checkpoint file name.
+        net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
+        strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
+                                      parameter into net when parameter name's suffix in checkpoint file is the
+                                      same as the parameter in the network. When the types are inconsistent
+                                      perform type conversion on the parameters of the same type, such as float32
+                                      to float16. Default: ``False`` .
+        filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
+            starting with the `filter_prefix` will not be loaded. Default: ``None`` .
+        dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
+            the decryption is not required. Default: ``None`` .
+        dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
+            the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"``
+            and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
+        specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
+            starting with the specify_prefix will be loaded. Default: ``None`` .
+        choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
+            string, and the return value is a bool. If returns ``True`` , the Parameter
+            that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
+            matches the custom condition will be removed. Default: ``None`` .
+
+    Returns:
+        A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
+
+    Raises:
+        ValueError: Checkpoint file's format is incorrect.
+        ValueError: Parameter's dict is None after load checkpoint file.
+        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import nn
+        >>> from mindspore.train import Model
+        >>> from mindspore.amp import FixedLossScaleManager
+        >>> from mindspore import context
+        >>> from mindspore import load_checkpoint_async
+        >>> from mindspore import load_param_into_net
+        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
+        >>> net = LeNet5()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        >>> loss_scale_manager = FixedLossScaleManager()
+        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
+        ...               loss_scale_manager=loss_scale_manager)
+        >>> pd_future = load_checkpoint_async(ckpt_file)
+        >>> model.build(train_dataset=dataset, epoch=2)
+        >>> param_dict = pd_future.result()
+        >>> load_param_into_net(net, param_dict)
+        >>> model.train(2, dataset)
+        >>> print("param dict len: ", len(param_dict), flush=True)
+    """
+    from concurrent.futures import ThreadPoolExecutor
+    executor = ThreadPoolExecutor(max_workers=2)
+    param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
+                                        dec_key, dec_mode, specify_prefix, choice_func)
+    return ParamDictFuture(executor, param_dict_future)
+
+
 def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
                         map_shape_list, parameter_dict):
     """load map parameter."""
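Note on the hunk above: `load_checkpoint_async` only submits `load_checkpoint` to a two-worker `ThreadPoolExecutor` and returns a `ParamDictFuture`, whose `result()` blocks, shuts the pool down, and hands back the parameter dict. The same pattern in isolation, using only the standard library; `slow_load` here is a placeholder for the real checkpoint parsing:

    from concurrent.futures import ThreadPoolExecutor
    import time

    class ResultFuture:
        """Wraps a Future so that result() also releases the executor, as ParamDictFuture does."""
        def __init__(self, executor, future):
            self.executor = executor
            self.future = future

        def result(self):
            value = self.future.result()   # block until the background load finishes
            self.executor.shutdown()       # free the worker threads once the value is out
            return value

    def slow_load(path):
        time.sleep(0.1)                    # placeholder for parsing a checkpoint file
        return {"path": path}

    executor = ThreadPoolExecutor(max_workers=2)
    future = ResultFuture(executor, executor.submit(slow_load, "./demo.ckpt"))
    # ... do other setup work (e.g. compile the graph) while the load runs ...
    print(future.result())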
@@ -1303,8 +1385,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
                             on the parameters of the same type, such as float32 to float16. Default: ``False`` .
 
     Returns:
-        param_not_load (List), the parameter name in model which are not loaded into the network.
-        ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
+        - param_not_load (List), the parameter name in model which are not loaded into the network.
+        - ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
 
     Raises:
         TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -1313,7 +1395,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1323,7 +1405,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
 
     Tutorial Examples:
        - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     if not isinstance(net, nn.Cell):
         logger.critical("Failed to combine the net and the parameters.")
@@ -1369,18 +1451,13 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     if param_not_load and not strict_load:
         _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
 
-    logger.debug("Params not matched(in net but not in parameter_dict):")
-    for param_name in param_not_load:
-        logger.debug("%s", param_name)
-
     logger.info("Loading parameters into net is finished.")
     if param_not_load:
         logger.warning("For 'load_param_into_net', "
                        "{} parameters in the 'net' are not loaded, because they are not in the "
                        "'parameter_dict', please check whether the network structure is consistent "
                        "when training and loading checkpoint.".format(len(param_not_load)))
-
-        logger.warning("{} is not loaded.".format(param_name))
+        logger.warning("{} are not loaded.".format(param_not_load))
     return param_not_load, ckpt_not_load
 
 
@@ -1494,6 +1571,23 @@ def _save_graph(network, file_name):
         f.write(graph_pb)
 
 
+def _reshape_tensor(tensor, dst_shape):
+    """reshape tensor to dst shape"""
+    np_tensor = tensor.asnumpy()
+    np_tensor = np_tensor.reshape(dst_shape)
+    return Tensor(np_tensor, tensor.dtype)
+
+
+def _check_param_for_integrate_save(pipeline_stages, uniform_split):
+    """check whether current settings and parameters are supported in integrated save checkpoint mode"""
+    if pipeline_stages > 1:
+        raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
+    if uniform_split == 0:
+        raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
+                           "'integrated_save' to True, the checkpoint will be integrated save, it "
+                           "is only supports uniform split tensor now.")
+
+
 def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
     """
     Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
@@ -1507,7 +1601,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
         Tensor, the combined tensor which with the whole data value.
     """
     layout = parameter_layout_dict[param_name]
-    if len(layout) <
+    if len(layout) < 8:
         logger.info("The layout dict does not contain the key %s", param_name)
         return param_data
 
@@ -1515,6 +1609,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
     tensor_map = layout[1]
     uniform_split = layout[4]
     opt_shard_group = layout[5]
+    before_reshape_slice_shape = layout[2]
+    before_reshape_full_shape = layout[6]
+    after_reshape_slice_shape = layout[7]
+    do_reshape = False
+    if before_reshape_full_shape and after_reshape_slice_shape \
+            and after_reshape_slice_shape != before_reshape_slice_shape:
+        do_reshape = True
 
     allgather_net = None
     mp_weight = False
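Note on the hunk above: the three extra layout fields drive an optional reshape. When the layout records a full shape before reshape (`layout[6]`) and a sliced shape after reshape (`layout[7]`) that differs from the pre-reshape slice shape (`layout[2]`), the merged tensor is reshaped back at the end of `_get_merged_param_data`. A standalone sketch of that decision; the concrete layout values below are invented for illustration:

    # Illustrative only: index meanings follow the variable names in the code above,
    # and the sample layout contents are made up.
    def needs_reshape(layout):
        before_reshape_slice_shape = layout[2]
        before_reshape_full_shape = layout[6]
        after_reshape_slice_shape = layout[7]
        return bool(before_reshape_full_shape and after_reshape_slice_shape
                    and after_reshape_slice_shape != before_reshape_slice_shape)

    layout = [[8], [0], [4, 16], 0, 1, "", [64], [2, 32]]
    print(needs_reshape(layout))   # True: the slice shape changed from [4, 16] to [2, 32]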
@@ -1527,26 +1628,26 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
     else:
         logger.info("Need to create allgather net for %s", param_name)
         if integrated_save:
-
-                raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
-            if uniform_split == 0:
-                raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
-                                   "'integrated_save' to True, the checkpoint will be integrated save, it "
-                                   "is only supports uniform split tensor now.")
+            _check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
             # while any dim is not equal to -1, means param is split and needs to be merged
             # pipeline parallel need to be supported here later
             if mp_weight:
-                allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group)
+                allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
+                                                   tuple(after_reshape_slice_shape))
                 object.__setattr__(allgather_net, "keep_input_unchanged", True)
             elif opt_shard_group:
-                allgather_net = get_allgather_cell(opt_shard_group, False
+                allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
+                                                   tuple(after_reshape_slice_shape))
         elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
-            allgather_net = get_allgather_cell(opt_shard_group, False
+            allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
+                                               tuple(after_reshape_slice_shape))
         net.parallel_parameter_merge_net_dict[param_name] = allgather_net
     if allgather_net:
         param_data = allgather_net(param_data)
     if mp_weight and integrated_save:
         param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
+    if do_reshape:
+        param_data = _reshape_tensor(param_data, before_reshape_full_shape)
     return param_data
 
 
@@ -1556,7 +1657,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
 
     Note:
         1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
-        2. When file_name does not have a suffix, the system will automatically add one
+        2. When `file_name` does not have a suffix, the system will automatically add one
+           according to the `file_format`.
         3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
        4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
           class properties in calculations.
@@ -1586,9 +1688,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
           - For 'AIR' and 'ONNX' models, only customized encryption is supported.
          - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
            or Customized encryption.
-            Default: 'AES-GCM'
+            Default: ``'AES-GCM'``.
          - For details of using the customized encryption, please check the `tutorial
-            <https://mindspore.cn/mindarmour/docs/en/
+            <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
 
        - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
          preprocessing of the dataset into MindIR.
@@ -1602,32 +1704,49 @@ def export(net, *inputs, file_name, file_format, **kwargs):
        - customized_func (function): A python function used for customized function mode, which used for control
          the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
          Reference to 'my_func()' in
-          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/
+          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
          This function needs to ensure that its result is constant for any input. Users can refer to opaque
          predicates. If customized_func is set, then it should be passed to `load()` interface when loading
          obfuscated model.
        - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
          structure of obfuscated models corresponding to different random seeds is different. If
-          `obf_random_seed` is set, then it should be passed
+          `obf_random_seed` is set, then it should be passed
+          to :class:`mindspore.nn.GraphCell` interface when loading
          obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
          be set, and the latter mode would be applied if both of them are set.
 
        - incremental (bool): export MindIR incrementally.
 
+       - custom_func (function): Functions for custom defined export policies. This function will be used to
+         customize the model during network export. Currently only support for files with mindir format. The
+         function only accepts one input representing the proto object of the mindir file. When modifying a model,
+         it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
+         failure or functional errors. Default: ``None`` .
+
     Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>>
        >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
        >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
+       >>>
+       >>> # Export model in MindIR format and modified the model info using custom_func
+       >>> # The custom_func only support one input representing the Proto object of the model
+       >>> # And custom_func does not support return value
+       >>> def _custom_func(mindir_model):
+       ...     mindir_model.producer_name = "test11111"
+       ...     mindir_model.producer_version = "11.0"
+       ...     mindir_model.user_info["version"] = "11.0"
+       >>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
+
 
     Tutorial Examples:
        - `Saving and Loading the Model - Saving and Loading MindIR
-         <https://mindspore.cn/tutorials/en/
+         <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     old_ms_jit_value = context.get_context("jit_syntax_level")
     context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1690,7 +1809,7 @@ def _get_funcgraph(net, *inputs):
        >>> from mindspore import Tensor
        >>>
        >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
        >>> ms.get_funcgraph(net, input_tensor)
@@ -1712,6 +1831,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
     if "obf_config" in kwargs and file_format != "MINDIR":
         raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
+    if "custom_func" in kwargs and file_format != "MINDIR":
+        raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
     if file_format == 'AIR':
         _save_air(net, file_name, *inputs, **kwargs)
     elif file_format == 'ONNX':
@@ -1872,12 +1993,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
     data_file_name = os.path.join(dirname, external_local)
     f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
     try:
-
+        round = 0
         names = []
         for param_proto in model.graph.parameter:
             name = param_proto.name[param_proto.name.find(":") + 1:]
             names.append((name, param_proto))
-
+        names.sort(key=lambda x: x[0])
         for pairs in names:
             name = pairs[0]
             param_proto = pairs[1]
@@ -1900,8 +2021,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
             offset += (data_length + append_size)
             write_data = _encrypt_data(is_encrypt, write_data, kwargs)
             f.write(write_data)
-
-            logger.debug(f"writing {
+            round += 1
+            logger.debug(f"writing {round}th split data, name:{name}")
 
     graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
     if os.path.exists(graph_file_name):
@@ -1998,6 +2119,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
         dataset = kwargs.get('dataset')
         _save_dataset_to_mindir(model, dataset)
 
+    custom_func = kwargs.get('custom_func', None)
+    if custom_func is not None:
+        custom_func(model)
+
     save_together = _save_together(net_dict, model)
     is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
     if save_together:
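Note on the hunk above: inside `_save_mindir` the exported proto model is handed to `custom_func` in place just before serialization, which is what makes the `custom_func` kwarg of `export` work. A hedged sketch of the hook shape; `FakeModel` is a stand-in for the real MindIR proto object, not the library's class:

    # Sketch of the hook contract: the callback mutates the proto-like object in place
    # and is not expected to return a replacement.
    class FakeModel:
        def __init__(self):
            self.producer_name = ""
            self.producer_version = ""
            self.user_info = {}

    def _custom_func(mindir_model):
        mindir_model.producer_name = "my_exporter"
        mindir_model.user_info["version"] = "1.0"

    def save_with_hook(model, custom_func=None):
        if custom_func is not None:
            custom_func(model)        # mirrors the kwargs.get('custom_func') handling above
        return model

    print(save_with_hook(FakeModel(), _custom_func).producer_name)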
@@ -2428,7 +2553,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                                  in at least one of them. Default: ``None`` .
         strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
                             into net when parameter name's suffix in checkpoint file is the same as the
-                            parameter in the network. When the types are inconsistent perform type conversion
+                            parameter in the network. When the types are inconsistent, perform type conversion
                             on the parameters of the same type, such as float32 to float16. Default: ``False`` .
         dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
                                       is not required. Default: ``None`` .
@@ -2449,14 +2574,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
 
         For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
         Please see the `rank table startup
-        <https://www.mindspore.cn/tutorials/experts/en/
+        <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
         for more details.
 
         For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
-        <https://www.mindspore.cn/tutorials/experts/en/
+        <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
 
         For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
-        Startup <https://www.mindspore.cn/tutorials/experts/en/
+        Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
 
         >>> import os
         >>> import numpy as np
@@ -2722,11 +2847,10 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
     param_name = merged_param.name
     tensor_layout = predict_strategy[param_name]
     rank = get_rank()
-    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
     requires_grad = merged_param.requires_grad
     layerwise_parallel = merged_param.layerwise_parallel
-
-    if data_type == mstype.bfloat16:
+    if merged_param.data.dtype == mstype.bfloat16:
         split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
     else:
         split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
@@ -2794,7 +2918,7 @@ def _get_mindir_inputs(file_name):
 
 def convert_model(mindir_file, convert_file, file_format):
     """
-    Convert mindir model to other format model.
+    Convert mindir model to other format model. The current version only supports conversion to ONNX models.
 
     .. warning::
         This is an experimental API that is subject to change or deletion.