mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED

@@ -23,7 +23,7 @@ import os
 import shutil
 import stat
 import threading
-from threading import Thread,
+from threading import Thread, RLock
 from collections import defaultdict, OrderedDict
 from io import BytesIO
 
@@ -59,19 +59,23 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
 from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
 from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
     _restore_group_info_list
+from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
+    _store_warm_up_ptr_by_tensor_list, _cache_enable
 from mindspore.train._utils import read_proto
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
-    split_mindir
+    split_mindir, split_dynamic_mindir
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
+from ..ops.operations import Cast
 
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
-                     "Bool": mstype.bool_, "str": mstype.string}
+                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
 
 tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
-                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"
+                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U",
+                     "BFloat16": np.float32}
 
 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
 
@@ -79,7 +83,7 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
                          5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
                          11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
 
-_ckpt_mutex =
+_ckpt_mutex = RLock()
 
 # unit is KB
 SLICE_SIZE = 512 * 1024
@@ -89,6 +93,8 @@ PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
 ENCRYPT_BLOCK_SIZE = 64 * 1024
 INT_64_MAX = 9223372036854775807
 
+cpu_cast = Cast().set_device("CPU")
+
 
 def _special_process_par(par, new_par):
     """
@@ -105,7 +111,11 @@ def _special_process_par(par, new_par):
         if new_par.data.shape[par_shape_len + i] != 1:
             return False
 
-
+    if new_par.data.dtype == mstype.bfloat16:
+        new_val = cpu_cast(new_par.data, mstype.float32).asnumpy()
+    else:
+        new_val = new_par.data.asnumpy()
+
     new_val = new_val.reshape(par.data.shape)
     par.set_data(Tensor(new_val, par.data.dtype))
     return True
@@ -126,7 +136,10 @@ def _update_param(param, new_param, strict_load):
 
     if param.data.dtype != new_param.data.dtype:
         if _type_convert(param, new_param, strict_load):
-
+            if new_param.data.dtype == mstype.bfloat16:
+                new_tensor = cpu_cast(new_param.data, param.data.dtype)
+            else:
+                new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
             param.set_data(new_tensor, param.sliced)
             return
 
@@ -229,10 +242,16 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                     continue
                 if value[0] == "offload_parameter":
                     new_value = value[1:]
-
+                    if value[3].dtype == mstype.bfloat16:
+                        new_value[2] = cpu_cast(value[3], mstype.float32).asnumpy().reshape(-1)
+                    else:
+                        new_value[2] = value[3].asnumpy().reshape(-1)
                     _write_parameter_data(name, new_value, f, enc_key, plain_data)
                     _offload_if_config(value[3])
                     continue
+                if value[0] == "BFloat16_tensor":
+                    _write_bfloat16_data(name, value, f, enc_key, plain_data)
+                    continue
                 if isinstance(value[2], Tensor):
                     _write_hugeparameter(name, value, f)
                     continue
@@ -267,6 +286,21 @@ def _write_random_seed(name, value, f):
     f.write(checkpoint_list.SerializeToString())
 
 
+def _write_bfloat16_data(name, value, f, enc_key, plain_data):
+    """Write bfloat16 data into protobuf file"""
+    checkpoint_list = Checkpoint()
+    param_value = checkpoint_list.value.add()
+    param_value.tag = name
+    param_tensor = param_value.tensor
+    param_tensor.dims.extend(value[1])
+    param_tensor.tensor_type = value[2]
+    param_tensor.tensor_content = value[3].get_bytes()
+    if enc_key is None:
+        f.write(checkpoint_list.SerializeToString())
+    else:
+        plain_data.write(checkpoint_list.SerializeToString())
+
+
 def _write_parameter_data(name, value, f, enc_key, plain_data):
     """Write parameter data into protobuf file."""
     data_size = value[2].nbytes / 1024
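All of the BFloat16 handling introduced above funnels through one conversion: bfloat16 tensors are cast to float32 on the CPU before their bytes are written, and tensor_to_np_type maps "BFloat16" back to np.float32 on load. A minimal sketch of that conversion, assuming a 2.2.x install where bfloat16 tensors can be constructed:

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops.operations import Cast

# The same module-level helper the diff adds to serialization.py.
cpu_cast = Cast().set_device("CPU")

bf16 = Tensor(np.ones((2, 2), np.float32)).astype(mstype.bfloat16)
# _exec_save / _write_bfloat16_data serialize this float32 view of the data,
# tagged "BFloat16" so the original dtype can be restored on load.
as_f32 = cpu_cast(bf16, mstype.float32)
print(as_f32.dtype)  # Float32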
@@ -333,8 +367,8 @@ def _write_hugeparameter(name, value, f):
 
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
-    if not isinstance(save_obj, nn.Cell
-        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or
+    if not isinstance(save_obj, (nn.Cell, list, dict)):
+        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                         "but got {}.".format(type(save_obj)))
     if not isinstance(ckpt_file_name, str):
         raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
@@ -351,14 +385,15 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
 
 def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                     async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
-    """
+    r"""
     Save checkpoint to a specified file.
 
     Args:
-        save_obj (Union[Cell, list]): The
-
-
-
+        save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
+            list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
+            elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
+            `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
+            it can be the returned value of `mindspore.load_checkpoint()`.
         ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
         integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
         async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
@@ -370,16 +405,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
             Default: ``"AES-GCM"`` .
         choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
-            a parameter name in string type, and the
+            a parameter name in string type, and the returned value is a bool.
            If returns ``True`` , the Parameter that matching the custom condition will be saved.
            If returns ``False`` , the Parameter that not matching the custom condition will not
            be saved. Default: ``None`` .
        kwargs (dict): Configuration options dictionary.
 
-            - incremental (bool): Whether export checkpoint for MapParameter incrementally.
-
     Raises:
-        TypeError: If the parameter `save_obj` is not
+        TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
         TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
         TypeError: If the parameter `ckpt_file_name` is not string type.
 
@@ -387,17 +420,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ms.save_checkpoint(net, "./lenet.ckpt",
-
-        >>>
-        >>> print(
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
+        >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
+        >>> print(param_dict1)
         {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> params_list = net.trainable_params()
+        >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
+        >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
+        >>> print(param_dict2)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
+        >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
+        >>> print(param_dict3)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
     integrated_save = Validator.check_bool(integrated_save)
@@ -408,70 +451,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     map_param_inc = kwargs.get('incremental', False)
     logger.info("Execute the process of saving checkpoint files.")
 
-
-    if save_obj.ge_init and not save_obj.ge_sync_data:
-        from mindspore.train.callback._callback import set_cur_net
-        set_cur_net(save_obj)
-        save_obj.exec_checkpoint_graph()
-    parameter_layout_dict = save_obj.parameter_layout_dict
-    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
-        parameter_layout_dict = _get_parameter_layout()
-    if not _is_in_auto_parallel_mode():
-        save_obj.init_parameters_data()
-    param_dict = OrderedDict()
-    for _, param in save_obj.parameters_and_names():
-        not_sliced = not param.sliced
-        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
-        # All parameters are initialized immediately under PyNative mode, skip this judgement.
-        if is_graph_mode and _is_in_auto_parallel_mode() and (not_sliced or param.has_init):
-            continue
-        if choice_func is not None and not choice_func(param.name):
-            continue
-        param_dict[param.name] = param
-    param_list = []
-    if append_dict and "random_op" in append_dict:
-        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
-        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
-            random_byte = _executor._graph_executor.get_random_status(phase)
-            param_list.append({"name": "random_op", "data": random_byte})
-        append_dict.pop("random_op")
-    for (key, value) in param_dict.items():
-        each_param = {"name": key}
-        if isinstance(value, MapParameter):
-            each_param["data"] = value
-            param_list.append(each_param)
-            continue
-
-        if value.data.is_persistent_data():
-            # list save persistent_data: [Tensor, shape, type, param.key]
-            param_data = ["persistent_data"]
-            param_data.append(value.data)
-            param_data.append(value.param_info.origin_shape)
-            param_data.append(str(value.dtype))
-            param_data.append(value.key)
-        elif value.data.offload_file_path() != "":
-            # list save offload data: [Param, shape, type, param.key]
-            param_data = ["offload_parameter"]
-            param_tensor = value.data
-            if key in parameter_layout_dict:
-                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
-                                                      integrated_save)
-            param_data.append(param_tensor)
-            param_data.append(param_tensor.shape)
-            param_data.append(str(param_tensor.dtype))
-            param_data.append(value.key)
-        else:
-            param_data = Tensor(value.data.asnumpy())
-
-            # in automatic model parallel scenario, some parameters were split to all the devices,
-            # which should be combined before saving
-            if key in parameter_layout_dict:
-                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
-                                                    integrated_save)
-
-        each_param["data"] = param_data
-        param_list.append(each_param)
-    save_obj = param_list
+    save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
 
     if append_dict:
         append_info_list = []
@@ -479,7 +459,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             if not isinstance(value, str):
                 value = Tensor(value)
             append_info_list.append({"name": k_name, "data": value})
-
+        save_obj.extend(append_info_list)
 
     data_list = OrderedDict()
     with _ckpt_mutex:
@@ -499,6 +479,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                 elif param["data"][0] == "offload_parameter":
                     data_list[key].append("offload_parameter")
                     _save_param_list_data(data_list, key, param)
+                elif param["data"][0] == "BFloat16_tensor":
+                    data_list[key].append("BFloat16_tensor")
+                    _save_param_list_data(data_list, key, param)
+                continue
 
             if isinstance(param["data"], str):
                 data_list[key].append([0])
@@ -508,6 +492,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             else:
                 if isinstance(param["data"], Parameter):
                     param["data"].init_data()
+                if isinstance(param["data"], Tensor) and param["data"].dtype == mstype.bfloat16:
+                    data_list[key].append("BFloat16_tensor")
+                    dims = []
+                    for dim in param["data"].shape:
+                        dims.append(dim)
+                    data_list[key].append(dims)
+                    data_list[key].append("BFloat16")
+                    data_list[key].append(cpu_cast(param["data"], mstype.float32))
+                    continue
                 dims = []
                 if param['data'].shape == ():
                     dims.append(0)
@@ -517,7 +510,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                 data_list[key].append(dims)
                 tensor_type = str(param["data"].dtype)
                 data_list[key].append(tensor_type)
-
+                if param["data"].dtype == mstype.bfloat16:
+                    data = cpu_cast(param["data"], mstype.float32).asnumpy().reshape(-1)
+                else:
+                    data = param["data"].asnumpy().reshape(-1)
                 data_list[key].append(data)
 
     if async_save:
@@ -530,6 +526,130 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     logger.info("Saving checkpoint process is finished.")
 
 
+def _convert_list_to_param_list(save_obj, choice_func):
+    """Convert a list of Parameter to param_list."""
+    param_list = []
+    if not save_obj:
+        return param_list
+    if isinstance(save_obj[0], dict):
+        param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
+    else:
+        for param in save_obj:
+            if isinstance(param, Parameter):
+                if choice_func is not None and not choice_func(param.name):
+                    continue
+                each_param = {"name": param.name, "data": param}
+                param_list.append(each_param)
+            else:
+                raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
+                                f"the param should be parameter, but got {type(param)}")
+    return param_list
+
+
+def _convert_dict_to_param_dict(save_obj, choice_func):
+    """Convert a dict of Parameter to param_list."""
+    param_list = []
+    for (key, value) in save_obj.items():
+        if isinstance(key, str) and isinstance(value, (Parameter, str)):
+            if choice_func is not None and not choice_func(key):
+                continue
+            each_param = {"name": key, "data": value}
+            param_list.append(each_param)
+        else:
+            raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
+                            f"value should be Parameter, but got the type of key is {type(key)} and"
+                            f"the type of value is {type(value)}")
+    return param_list
+
+
+def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
+    """Convert cell.parameters_and_names to OrderedDict."""
+    param_dict = OrderedDict()
+    for _, param in save_obj.parameters_and_names():
+        not_sliced = not param.sliced
+        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
+        # All parameters are initialized immediately under PyNative mode, skip this judgement.
+        judgment = not_sliced or param.has_init
+        if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
+            continue
+        if choice_func is not None and not choice_func(param.name):
+            continue
+        # Add suffix for cache_enabled parameter, and then parameter can carry key info.
+        # Notice that suffix needs be removed when loading into net.
+        if param.cache_enable:
+            param_dict[param.name + ".__param_key__" + str(param.key)] = param
+        else:
+            param_dict[param.name] = param
+    return param_dict
+
+
+def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert nn.Cell to param_list."""
+    param_list = []
+    parameter_layout_dict = save_obj.parameter_layout_dict
+    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
+        parameter_layout_dict = _get_parameter_layout()
+    if not _is_in_auto_parallel_mode():
+        save_obj.init_parameters_data()
+    param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
+    if append_dict and "random_op" in append_dict:
+        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
+        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
+            random_byte = _executor._graph_executor.get_random_status(phase)
+            param_list.append({"name": "random_op", "data": random_byte})
+        append_dict.pop("random_op")
+    for (key, value) in param_dict.items():
+        each_param = {"name": key}
+        if isinstance(value, MapParameter):
+            each_param["data"] = value
+            param_list.append(each_param)
+            continue
+
+        if value.data.is_persistent_data():
+            # list save persistent_data: [Tensor, shape, type, param.key]
+            param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
+        elif value.data.offload_file_path() != "":
+            # list save offload data: [Param, shape, type, param.key]
+            param_data = ["offload_parameter"]
+            param_tensor = value.data
+            if key in parameter_layout_dict:
+                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
+                                                      integrated_save)
+            param_data.append(param_tensor)
+            param_data.append(param_tensor.shape)
+            param_data.append(str(param_tensor.dtype))
+            param_data.append(value.key)
+        elif value.data.dtype == mstype.bfloat16:
+            param_data = ["BFloat16_tensor"]
+            param_data.append(cpu_cast(value.data, mstype.float32))
+            param_data.append(value.data.shape)
+            param_data.append("BFloat16")
+            param_data.append(value.key)
+        else:
+            param_data = Tensor(value.data.asnumpy())
+
+            # in automatic model parallel scenario, some parameters were split to all the devices,
+            # which should be combined before saving
+            if key in parameter_layout_dict:
+                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                    integrated_save)
+
+        each_param["data"] = param_data
+        param_list.append(each_param)
+    return param_list
+
+
+def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert a save_obj to param_list."""
+    if isinstance(save_obj, list):
+        return _convert_list_to_param_list(save_obj, choice_func)
+
+    if isinstance(save_obj, dict):
+        return _convert_dict_to_param_dict(save_obj, choice_func)
+
+    return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+
 def _save_param_list_data(data_list, key, param):
     """Save persistent data into save_obj."""
     dims = []
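With save_checkpoint now routing every input through _convert_save_obj_to_param_list, save_obj may be an nn.Cell, a list, or a dict. A condensed, runnable sketch of the round-trip shown in the new docstring above (it assumes a LeNet5 Cell defined as in the linked MindSpore docs):

import mindspore as ms

net = LeNet5()  # any nn.Cell; LeNet5 is the docstring's example network
# Cell input (existing behavior), with choice_func filtering by parameter name.
ms.save_checkpoint(net, "./lenet.ckpt",
                   choice_func=lambda name: name.startswith("conv"))
# list input: e.g. the return value of Cell.trainable_params().
ms.save_checkpoint(net.trainable_params(), "./lenet_list.ckpt")
# dict input (new in 2.2): e.g. the return value of load_checkpoint().
param_dict = ms.load_checkpoint("./lenet.ckpt")
ms.save_checkpoint(param_dict, "./lenet_dict.ckpt")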
@@ -585,7 +705,7 @@ def load(file_name, **kwargs):
 
         - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
           `obfuscate_model()
-          <https://www.mindspore.cn/docs/en/r2.
+          <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
 
     Returns:
         GraphCell, a compiled graph that can executed by `GraphCell`.
@@ -615,7 +735,7 @@ def load(file_name, **kwargs):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     if not isinstance(file_name, str):
         raise ValueError("For 'load', the argument 'file_name' must be string, but "
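For reference, a hedged usage sketch of the `load` API documented in this hunk: the returned graph is executed through `nn.GraphCell` (the MindIR file name is illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    graph = ms.load("net.mindir")    # compiled MindIR graph
    net = nn.GraphCell(graph)        # wrap it for execution
    output = net(Tensor(np.ones((1, 1, 32, 32), np.float32)))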
@@ -656,7 +776,7 @@ def load(file_name, **kwargs):
     return graph
 
 
-def export_split_mindir(file_name):
+def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
     """
     Auto Split MindIR.
 
@@ -664,6 +784,10 @@ def export_split_mindir(file_name):
 
     Args:
         file_name (str): MindIR file name.
+        device_num (int): device number.
+        rank_id (int): rank id.
+        dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
+        sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
 
     Raises:
         ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -671,11 +795,9 @@ def export_split_mindir(file_name):
 
     Examples:
         >>> import mindspore as ms
-        >>> from mindspore.communication import init
         >>> context.set_context(mode=context.GRAPH_MODE)
         >>>
-        >>>
-        >>> ms.export_split_mindir("net.mindir")
+        >>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
 
     """
     if not isinstance(file_name, str):
@@ -690,8 +812,11 @@ def export_split_mindir(file_name):
     file_name = os.path.abspath(file_name)
 
     logger.info("Execute the process of export and split mindir.")
-
-
+    dynamic = True
+    if dynamic:
+        graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
+    else:
+        graph = split_mindir(file_name)
 
     if graph is None:
         if _is_cipher_file(file_name):
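A hedged sketch of the extended signature introduced above. Note that the body now hard-codes `dynamic = True` immediately before the branch, so the `split_dynamic_mindir` path is taken regardless of the `dynamic` argument:

    import mindspore as ms

    # The keyword arguments below are the new parameters added in this hunk.
    ms.export_split_mindir("net.mindir", device_num=8, rank_id=0, dynamic=True, sapp=True)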
@@ -779,17 +904,20 @@ def obfuscate_model(obf_config, **kwargs):
             - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
               is the same as using :func:`mindspore.export`.
             - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
-              should be in range of (0, 1] or in ["small", "medium", "large"].
+              should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
+              correspond to 0.1, 0.3, and 0.6 respectively.
             - customized_func (function): A python function used for customized function mode, which used for control
-              the switch branch of obfuscation structure. The outputs of customized_func should be boolean
-
+              the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
+              Reference to 'my_func()' in
+              `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
+              This function needs to ensure that its result is constant for any input. Users can refer to opaque
               predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
               when loading obfuscated model.
-            - obf_random_seed (int):
-
-              then it should be passed to :class:`nn.GraphCell()` interface when loading
-              noted that at least one of `customized_func` or `obf_random_seed` should
-              would be applied if both of them are set.
+            - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+              structure of obfuscated models corresponding to different random seeds is different. If
+              `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
+              obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
+              be set, and the latter mode would be applied if both of them are set.
 
         kwargs (dict): Configuration options dictionary.
 
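A hedged sketch of an `obf_config` using the random-seed mode described above; only keys mentioned in this hunk are shown, other required entries (such as the model paths) are omitted, and all values are illustrative:

    import numpy as np
    import mindspore as ms

    obf_config = {
        "model_inputs": [ms.Tensor(np.ones((1, 1, 32, 32), np.float32))],
        "obf_ratio": "small",      # "small" maps to 0.1 of nodes per the docstring
        "obf_random_seed": 3423,   # must fall in (0, 9223372036854775807]
    }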
@@ -928,27 +1056,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         >>> print(param_dict["conv2.weight"])
         Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
         >>> def func(param_name):
-
-
-
-
-
-
+        ...     whether_load = False
+        ...     if param_name.startswith("conv"):
+        ...         whether_load = True
+        ...     if param_name.startswith("conv1"):
+        ...         whether_load = False
+        ...     return whether_load
         >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
         >>> print(param_dict1["conv2.weight"])
         Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
         >>> def func(param_name):
-
-
-
-
+        ...     whether_load = False
+        ...     if param_name.startswith("conv1"):
+        ...         whether_load = True
+        ...     return whether_load
         >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
         >>> print(param_dict2)
         {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
     specify_prefix = _check_prefix(specify_prefix)
@@ -979,8 +1107,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
                 choice_func is not None and not choice_func(element.tag):
             continue
         if element.tensor.ByteSize() == 0:
-            _load_map_parameter(checkpoint_list, element, element_id,
-                                map_data_list, map_shape_list, parameter_dict)
+            _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
             if element.tag in parameter_dict:
                 map_data_list = [[], [], []]
                 map_shape_list = [0, 0, 0]
@@ -992,6 +1119,13 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         if data_type == 'str':
             str_length = int(len(data) / 4)
             np_type = np_type + str(str_length)
+        if data_type == "BFloat16":
+            dims = element.tensor.dims
+            param_data = np.frombuffer(data, np_type)
+            param_data = param_data.reshape(list(dims))
+            parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
+            parameter_dict[element.tag] = parameter
+            continue
         element_data = np.frombuffer(data, np_type)
         param_data_list.append(element_data)
         if (element_id == len(checkpoint_list.value) - 1) or \
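A hedged sketch of the BFloat16 restore branch added above: the checkpoint stores float32 bytes, which are rebuilt with NumPy and wrapped into a bfloat16 Parameter (name, shape, and values illustrative):

    import numpy as np
    from mindspore import Parameter, Tensor
    from mindspore import dtype as mstype

    data = np.ones(4, np.float32).tobytes()    # stand-in for raw checkpoint bytes
    param_data = np.frombuffer(data, np.float32).reshape([2, 2])
    parameter = Parameter(Tensor(param_data, mstype.bfloat16), name="dense.weight")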
@@ -1024,8 +1158,12 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
                          f"'filter_prefix' or 'specify_prefix' are set correctly.")
 
+    if _warm_up_host_cache_enabled(parameter_dict):
+        (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
     if net is not None:
         load_param_into_net(net, parameter_dict, strict_load)
+    if _warm_up_host_cache_enabled(parameter_dict):
+        _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
 
     return parameter_dict
 
@@ -1061,7 +1199,7 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
 
 
 def _check_ckpt_file_name(ckpt_file_name):
-    """Check function load_checkpoint's
+    """Check function load_checkpoint's ckpt_file_name."""
     if not isinstance(ckpt_file_name, str):
         raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
                         "but got {}.".format(type(ckpt_file_name)))
@@ -1175,7 +1313,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
         >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -1185,7 +1323,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     if not isinstance(net, nn.Cell):
         logger.critical("Failed to combine the net and the parameters.")
@@ -1219,6 +1357,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
             if isinstance(param, MapParameter):
                 param.import_data(parameter_dict[param.name])
                 continue
+            # Add has attr protection when load server checkpoint file on worker.
+            if not hasattr(parameter_dict[param.name], "data"):
+                continue
             new_param = copy.deepcopy(parameter_dict[param.name])
             _update_param(param, new_param, strict_load)
             ckpt_not_load.remove(param.name)
@@ -1243,6 +1384,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     return param_not_load, ckpt_not_load
 
 
+def _warm_up_host_cache_enabled(parameter_dict):
+    """Warm up host cache enabled."""
+    if _cache_enable():
+        return True
+    for key in parameter_dict.keys():
+        if key.find(".__param_key__") != -1:
+            return True
+    return False
+
+
+def _warm_up_host_cache(parameter_dict, net):
+    """Warm up host cache."""
+    ms_role = os.getenv("MS_ROLE")
+    is_worker = ms_role == "MS_WORKER"
+    param_key_dict = {}
+    # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
+    if is_worker:
+        net.init_parameters_data()
+        net_dict = {}
+        for name, value in net.parameters_and_names():
+            net_dict[name] = value
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+                net_value = None
+                if net_param_name not in net_dict:
+                    logger.warning("net param name : %s is not in net", net_param_name)
+                else:
+                    net_value = net_dict.get(net_param_name, None)
+                pos += len(".__param_key__")
+                param_key = int(param_name[pos:])
+                value_is_map_parameter = isinstance(value, list) and len(value) == 3
+                if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
+                    key_tensor = Tensor.from_numpy(value[0])
+                    value_tensor = Tensor.from_numpy(value[1])
+                    status_tensor = Tensor.from_numpy(value[2])
+                    _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
+                elif not isinstance(value, list) and isinstance(net_value, Parameter):
+                    _store_warm_up_ptr_by_tensor(param_key, value)
+                else:
+                    logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value))
+    else:
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+    # Split param key from parameter_dict since worker cannot load param key.
+    warm_up_dict = {}
+    for key, value in param_key_dict.items():
+        if is_worker:
+            warm_up_dict[value] = parameter_dict.pop(key)
+        else:
+            parameter_dict[value] = parameter_dict.pop(key)
+    return (is_worker, parameter_dict, warm_up_dict)
+
+
+def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
+    """Warm up host cache post process."""
+    if is_worker:
+        net_dict.update(warm_up_dict)
+    _set_checkpoint_load_status(True)
+
+
 def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
     """When some net parameter did not load, try to continue loading."""
     prefix_name = ""
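The warm-up helpers above key off a naming convention; a minimal illustration of how the `.__param_key__` suffix is parsed (the name is hypothetical):

    param_name = "embedding.table.__param_key__7"
    pos = param_name.find(".__param_key__")
    net_param_name = param_name[:pos]                          # "embedding.table"
    param_key = int(param_name[pos + len(".__param_key__"):])  # 7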
@@ -1350,9 +1557,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
     Note:
         1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
         2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
-        3. Exporting functions decorated with
-        4. When exporting a function decorated with
-           calculations.
+        3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
+        4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
+           class properties in calculations.
 
     Args:
         net (Union[Cell, function]): MindSpore network.
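Notes 3 and 4 above concern jit-decorated functions; a hedged sketch of exporting one to MindIR (function and file name illustrative):

    import numpy as np
    import mindspore as ms

    @ms.jit
    def add_fn(x, y):
        return x + y

    x = ms.Tensor(np.ones((2, 2), np.float32))
    y = ms.Tensor(np.ones((2, 2), np.float32))
    ms.export(add_fn, x, y, file_name="add_fn", file_format="MINDIR")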
@@ -1388,17 +1595,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
 
             - type (str): The type of obfuscation, only 'dynamic' is supported until now.
             - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
-              should be in range of (0, 1] or in ["small", "medium", "large"].
+              should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
+              correspond to 0.1, 0.3, and 0.6 respectively.
             - customized_func (function): A python function used for customized function mode, which used for control
-              the switch branch of obfuscation structure. The outputs of customized_func should be boolean
-
+              the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
+              Reference to 'my_func()' in
+              `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
+              This function needs to ensure that its result is constant for any input. Users can refer to opaque
               predicates. If customized_func is set, then it should be passed to `load()` interface when loading
               obfuscated model.
-            - obf_random_seed (int):
-
-              then it should be passed to :class:`nn.GraphCell()` interface when loading
-              be noted that at least one of `customized_func` or `obf_random_seed` should
-              would be applied if both of them are set.
+            - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+              structure of obfuscated models corresponding to different random seeds is different. If
+              `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
+              obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
+              be set, and the latter mode would be applied if both of them are set.
 
         - incremental (bool): export MindIR incrementally.
 
@@ -1408,14 +1618,14 @@ def export(net, *inputs, file_name, file_format, **kwargs):
         >>> from mindspore import Tensor
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
         >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     old_ms_jit_value = context.get_context("jit_syntax_level")
     context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1475,7 +1685,7 @@ def _get_funcgraph(net, *inputs):
         >>> from mindspore import Tensor
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
         >>> ms.get_funcgraph(net, input_tensor)
@@ -1657,10 +1867,17 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
     data_file_name = os.path.join(dirname, external_local)
     f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
     try:
+        round_ = 0
+        names = []
         for param_proto in model.graph.parameter:
             name = param_proto.name[param_proto.name.find(":") + 1:]
+            names.append((name, param_proto))
+        names.sort(key=lambda x: x[0])
+        for pairs in names:
+            name = pairs[0]
+            param_proto = pairs[1]
             param = net_dict[name]
-            raw_data = param.data.
+            raw_data = param.data.get_bytes()
             data_length = len(raw_data)
             append_size = 0
             if data_length % 64 != 0:
@@ -1678,6 +1895,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
                 offset += (data_length + append_size)
             write_data = _encrypt_data(is_encrypt, write_data, kwargs)
             f.write(write_data)
+            round_ += 1
+            logger.debug(f"writing {round_}th split data, name:{name}")
 
         graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
         if os.path.exists(graph_file_name):
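The restructuring above first collects (name, proto) pairs and sorts them by name, so parameters land in the external data file in a deterministic order across runs. A minimal sketch of the pattern, with stand-in values instead of real parameter protos:

    protos = {"beta": b"\x01", "alpha": b"\x02"}        # stand-ins for param protos
    names = sorted(protos.items(), key=lambda x: x[0])  # fixed order across runs
    for name, raw_data in names:
        pass  # write raw_data for `name`, as _split_save now does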
@@ -1787,7 +2006,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
     for param_proto in model.graph.parameter:
         param_name = param_proto.name[param_proto.name.find(":") + 1:]
         if param_name in net_dict.keys():
-            param_data = net_dict[param_name].data.
+            param_data = net_dict[param_name].data.get_bytes()
             param_proto.raw_data = param_data
         else:
             raise ValueError("The parameter '{}' is not belongs to any cell,"
@@ -1797,10 +2016,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
         map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
         if map_param_name in net_dict.keys():
             map_parameter = net_dict[map_param_name]
-
-            map_param_proto.key_tensor.raw_data =
-            map_param_proto.value_tensor.raw_data =
-            map_param_proto.status_tensor.raw_data =
+            key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
+            map_param_proto.key_tensor.raw_data = key_bytes
+            map_param_proto.value_tensor.raw_data = value_bytes
+            map_param_proto.status_tensor.raw_data = status_bytes
         else:
             raise ValueError("The map_parameter '{}' is not belongs to any cell,"
                              "the data of parameter cannot be exported.".format(map_param_proto.name))
@@ -1831,7 +2050,7 @@ def _save_together(net_dict, model):
     for param_proto in model.graph.parameter:
         name = param_proto.name[param_proto.name.find(":") + 1:]
         if name in net_dict.keys():
-            data_total += sys.getsizeof(net_dict[name].data.
+            data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
         else:
             raise ValueError("The parameter '{}' is not belongs to any cell,"
                              "the data of parameter cannot be exported.".format(param_proto.name))
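The `get_bytes()` calls introduced in these hunks serialize a tensor's data to raw bytes, so the size accounting divides by 1024 to approximate kilobytes. A hedged sketch:

    import sys
    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.ones((256, 256), np.float32))
    size_kb = sys.getsizeof(t.get_bytes()) / 1024  # ~256 KB of data plus object overhead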
@@ -1862,7 +2081,7 @@ def _save_dataset_to_mindir(model, dataset):
 
 def parse_print(print_file_name):
     """
-    Parse data file generated by mindspore.ops.Print
+    Parse data file generated by :class:`mindspore.ops.Print`.
 
     Args:
         print_file_name (str): The file name needs to be parsed.
@@ -2039,8 +2258,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
 def restore_group_info_list(group_info_file_name):
     """
     Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
-    who saves the group_info_file_name
-    like "export GROUP_INFO_FILE=/data/group_info.pb".
+    who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FILE
+    environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
 
     Args:
         group_info_file_name (str): Name of group information file.
@@ -2050,7 +2269,7 @@ def restore_group_info_list(group_info_file_name):
 
     Raises:
         ValueError: group information file is incorrect.
-        TypeError: group_info_file_name is not str.
+        TypeError: `group_info_file_name` is not str.
 
     Examples:
         >>> import mindspore as ms
@@ -2072,9 +2291,6 @@ def restore_group_info_list(group_info_file_name):
 def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
-    For details of it, please check:
-    `Saving and Loading Models in Hybrid Parallel Mode
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
 
     Args:
         strategy_filename (str): Name of strategy file.
@@ -2096,8 +2312,6 @@ def build_searched_strategy(strategy_filename):
 def merge_sliced_parameter(sliced_parameters, strategy=None):
     """
     Merge parameter slices into one parameter. Used in the case of distributed inference.
-    For details of it, please check:
-    `<https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.
 
     Args:
         sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
@@ -2171,7 +2385,12 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
 
     layerwise_parallel = sliced_parameters[0].layerwise_parallel
     requires_grad = sliced_parameters[0].requires_grad
-    sliced_data = [
+    sliced_data = []
+    for parameter in sliced_parameters:
+        if parameter.data.dtype == mstype.bfloat16:
+            sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
+        else:
+            sliced_data.append(parameter.data.asnumpy())
 
     if not strategy:
         merged_tensor = Tensor(np.concatenate(sliced_data))
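A hedged usage sketch for `merge_sliced_parameter`: with no strategy, the (now bfloat16-aware) slices are concatenated along axis 0 (names and values illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Parameter, Tensor

    slices = [Parameter(Tensor(np.array([0., 1.], np.float32)), "net.w"),
              Parameter(Tensor(np.array([2., 3.], np.float32)), "net.w")]
    merged = ms.merge_sliced_parameter(slices)  # Parameter "net.w" with data [0. 1. 2. 3.]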
@@ -2191,9 +2410,6 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                                 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
     """
     Load checkpoint into net for distributed predication. Used in the case of distributed inference.
-    For details of distributed inference, please check:
-    `Distributed Inference
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/distributed_inference.html>`_ .
 
     Args:
         network (Cell): Network for distributed predication.
@@ -2218,6 +2434,104 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
     Raises:
         TypeError: The type of inputs do not match the requirements.
         ValueError: Failed to load checkpoint into net.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
+
+        >>> import os
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> import mindspore.dataset as ds
+        >>> from mindspore import nn, ops, train
+        >>> from mindspore.communication import init
+        >>>
+        >>> step_per_epoch = 4
+        >>> device_num = 8
+        >>>
+        >>> # Define the network structure.
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, matmul_size, strategy=None):
+        ...         super().__init__()
+        ...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
+        ...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
+        ...         self.matmul = ops.MatMul()
+        ...         self.neg = ops.Neg()
+        ...         if strategy is not None:
+        ...             self.matmul.shard(strategy)
+        ...
+        ...     def construct(self, inputs):
+        ...         x = self.matmul(inputs, self.matmul_weight)
+        ...         x = self.neg(x)
+        ...         return x
+        >>>
+        >>> # Create dataset.
+        >>> def get_dataset(*inputs):
+        ...     def generate():
+        ...         for _ in range(step_per_epoch):
+        ...             yield inputs
+        ...     return generate
+        >>>
+        >>> # Train network and save distributed checkpoint.
+        >>> def train_net():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     np.random.seed(1)
+        ...     input_data = np.random.rand(16, 96).astype(np.float32)
+        ...     label_data = np.random.rand(16, 16).astype(np.float32)
+        ...     fake_dataset = get_dataset(input_data, label_data)
+        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+        ...
+        ...     # Set parallel strategy.
+        ...     strategy = ((1, 4), (4, 1))
+        ...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
+        ...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
+        ...     network = Net(matmul_size=(96, 16), strategy=strategy)
+        ...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
+        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
+        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
+        ...     global_rank_id = int(os.getenv("RANK_ID"))
+        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
+        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
+        ...     ms.reset_auto_parallel_context()
+        >>>
+        >>> # Load distributed checkpoint and test.
+        >>> def load_model():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
+        ...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
+        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
+        ...     network = Net(matmul_size=(96, 16))
+        ...     model = ms.Model(network)
+        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
+        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
+        ...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
+        ...     predict_result = model.predict(predict_data)
+        ...     print(predict_result)
+        >>>
+        >>> train_net()
+        >>> load_model()
+        [[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
+         [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
+         ...
+         [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
     """
     network = Validator.check_isinstance("network", network, nn.Cell)
     _check_checkpoint_file(checkpoint_filenames)
@@ -2282,7 +2596,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             param_index = list(set(param_index))
             param_index.sort()
             for rank_num in param_index:
-
+                if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
+                    param_stride.append(
+                        cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
+                else:
+                    param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
 
             sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
         else:
@@ -2297,7 +2615,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
         opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
         if opt_shard_group:
-
+            if split_param.data.dtype == mstype.bfloat16:
+                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
+            else:
+                data = split_param.data.asnumpy()
             rank = get_rank(opt_shard_group)
             size = get_group_size(opt_shard_group)
             try:
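The bfloat16 branches in these hunks exist because NumPy has no native bfloat16: data detours through float32 on the host (`cpu_cast(...).asnumpy()`) and the dtype is restored when the Parameter is rebuilt, as `_merge_and_split` does in the next hunk. A minimal sketch:

    import numpy as np
    from mindspore import Parameter, Tensor
    from mindspore import dtype as mstype

    host_data = np.ones((2, 2), np.float32)  # float32 stand-in for bf16 values
    param = Parameter(Tensor(host_data, mstype.bfloat16), "w")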
@@ -2395,10 +2716,15 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
         return merged_param
     param_name = merged_param.name
     tensor_layout = predict_strategy[param_name]
-
+    rank = get_rank()
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
     requires_grad = merged_param.requires_grad
     layerwise_parallel = merged_param.layerwise_parallel
-
+    data_type = merged_param.data.dtype
+    if data_type == mstype.bfloat16:
+        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
+    else:
+        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
     return split_param
 
 
@@ -2407,7 +2733,7 @@ def _calculation_net_size(net):
     data_total = 0
     net_dict = net.parameters_dict()
     for name in net_dict:
-        data_total += sys.getsizeof(net_dict[name].data.
+        data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
 
     return data_total
 