mindspore 2.1.0-cp38-none-any.whl → 2.2.0-cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
@@ -23,7 +23,7 @@ import os
 import shutil
 import stat
 import threading
-from threading import Thread,
+from threading import Thread, RLock
 from collections import defaultdict, OrderedDict
 from io import BytesIO
 
@@ -59,9 +59,11 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
 from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
 from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
     _restore_group_info_list
+from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
+    _store_warm_up_ptr_by_tensor_list, _cache_enable
 from mindspore.train._utils import read_proto
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
-    split_mindir
+    split_mindir, split_dynamic_mindir
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
 
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -79,7 +81,7 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
                          5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
                          11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
 
-_ckpt_mutex =
+_ckpt_mutex = RLock()
 
 # unit is KB
 SLICE_SIZE = 512 * 1024
@@ -333,8 +335,8 @@ def _write_hugeparameter(name, value, f):
 
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
-    if not isinstance(save_obj, nn.Cell
-        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or
+    if not isinstance(save_obj, (nn.Cell, list, dict)):
+        raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                         "but got {}.".format(type(save_obj)))
     if not isinstance(ckpt_file_name, str):
         raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
@@ -351,14 +353,15 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
 
 def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                     async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
-    """
+    r"""
     Save checkpoint to a specified file.
 
     Args:
-        save_obj (Union[Cell, list]): The
-
-
-
+        save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
+            list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
+            elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
+            `param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
+            it can be the returned value of `mindspore.load_checkpoint()`.
         ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
         integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
         async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
@@ -370,16 +373,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
             Default: ``"AES-GCM"`` .
         choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
-            a parameter name in string type, and the
+            a parameter name in string type, and the returned value is a bool.
             If returns ``True`` , the Parameter that matching the custom condition will be saved.
             If returns ``False`` , the Parameter that not matching the custom condition will not
             be saved. Default: ``None`` .
         kwargs (dict): Configuration options dictionary.
 
-            - incremental (bool): Whether export checkpoint for MapParameter incrementally.
-
     Raises:
-        TypeError: If the parameter `save_obj` is not
+        TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
         TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
         TypeError: If the parameter `ckpt_file_name` is not string type.
 
@@ -387,17 +388,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> ms.save_checkpoint(net, "./lenet.ckpt",
-
-        >>>
-        >>> print(
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
+        >>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
+        >>> print(param_dict1)
         {'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> params_list = net.trainable_params()
+        >>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
+        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
+        >>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
+        >>> print(param_dict2)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
+        >>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
+        >>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
+        >>> print(param_dict3)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
     integrated_save = Validator.check_bool(integrated_save)
@@ -408,70 +419,7 @@
     map_param_inc = kwargs.get('incremental', False)
     logger.info("Execute the process of saving checkpoint files.")
 
-
-    if save_obj.ge_init and not save_obj.ge_sync_data:
-        from mindspore.train.callback._callback import set_cur_net
-        set_cur_net(save_obj)
-        save_obj.exec_checkpoint_graph()
-    parameter_layout_dict = save_obj.parameter_layout_dict
-    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
-        parameter_layout_dict = _get_parameter_layout()
-    if not _is_in_auto_parallel_mode():
-        save_obj.init_parameters_data()
-    param_dict = OrderedDict()
-    for _, param in save_obj.parameters_and_names():
-        not_sliced = not param.sliced
-        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
-        # All parameters are initialized immediately under PyNative mode, skip this judgement.
-        if is_graph_mode and _is_in_auto_parallel_mode() and (not_sliced or param.has_init):
-            continue
-        if choice_func is not None and not choice_func(param.name):
-            continue
-        param_dict[param.name] = param
-    param_list = []
-    if append_dict and "random_op" in append_dict:
-        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
-        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
-            random_byte = _executor._graph_executor.get_random_status(phase)
-            param_list.append({"name": "random_op", "data": random_byte})
-            append_dict.pop("random_op")
-    for (key, value) in param_dict.items():
-        each_param = {"name": key}
-        if isinstance(value, MapParameter):
-            each_param["data"] = value
-            param_list.append(each_param)
-            continue
-
-        if value.data.is_persistent_data():
-            # list save persistent_data: [Tensor, shape, type, param.key]
-            param_data = ["persistent_data"]
-            param_data.append(value.data)
-            param_data.append(value.param_info.origin_shape)
-            param_data.append(str(value.dtype))
-            param_data.append(value.key)
-        elif value.data.offload_file_path() != "":
-            # list save offload data: [Param, shape, type, param.key]
-            param_data = ["offload_parameter"]
-            param_tensor = value.data
-            if key in parameter_layout_dict:
-                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
-                                                      integrated_save)
-            param_data.append(param_tensor)
-            param_data.append(param_tensor.shape)
-            param_data.append(str(param_tensor.dtype))
-            param_data.append(value.key)
-        else:
-            param_data = Tensor(value.data.asnumpy())
-
-            # in automatic model parallel scenario, some parameters were split to all the devices,
-            # which should be combined before saving
-            if key in parameter_layout_dict:
-                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
-                                                    integrated_save)
-
-        each_param["data"] = param_data
-        param_list.append(each_param)
-    save_obj = param_list
+    save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
 
     if append_dict:
         append_info_list = []
@@ -479,7 +427,7 @@
             if not isinstance(value, str):
                 value = Tensor(value)
             append_info_list.append({"name": k_name, "data": value})
-
+        save_obj.extend(append_info_list)
 
     data_list = OrderedDict()
     with _ckpt_mutex:
@@ -530,6 +478,124 @@
     logger.info("Saving checkpoint process is finished.")
 
 
+def _convert_list_to_param_list(save_obj, choice_func):
+    """Convert a list of Parameter to param_list."""
+    param_list = []
+    if not save_obj:
+        return param_list
+    if isinstance(save_obj[0], dict):
+        param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
+    else:
+        for param in save_obj:
+            if isinstance(param, Parameter):
+                if choice_func is not None and not choice_func(param.name):
+                    continue
+                each_param = {"name": param.name, "data": param}
+                param_list.append(each_param)
+            else:
+                raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
+                                f"the param should be parameter, but got {type(param)}")
+    return param_list
+
+
+def _convert_dict_to_param_dict(save_obj, choice_func):
+    """Convert a dict of Parameter to param_list."""
+    param_list = []
+    for (key, value) in save_obj.items():
+        if isinstance(key, str) and isinstance(value, (Parameter, str)):
+            if choice_func is not None and not choice_func(key):
+                continue
+            each_param = {"name": key, "data": value}
+            param_list.append(each_param)
+        else:
+            raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
+                            f"value should be Parameter, but got the type of key is {type(key)} and"
+                            f"the type of value is {type(value)}")
+    return param_list
+
+
+def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
+    """Convert cell.parameters_and_names to OrderedDict."""
+    param_dict = OrderedDict()
+    for _, param in save_obj.parameters_and_names():
+        not_sliced = not param.sliced
+        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
+        # All parameters are initialized immediately under PyNative mode, skip this judgement.
+        judgment = not_sliced or param.has_init
+        if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
+            continue
+        if choice_func is not None and not choice_func(param.name):
+            continue
+        # Add suffix for cache_enabled parameter, and then parameter can carry key info.
+        # Notice that suffix needs be removed when loading into net.
+        if param.cache_enable:
+            param_dict[param.name + ".__param_key__" + str(param.key)] = param
+        else:
+            param_dict[param.name] = param
+    return param_dict
+
+
+def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert nn.Cell to param_list."""
+    param_list = []
+    parameter_layout_dict = save_obj.parameter_layout_dict
+    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
+        parameter_layout_dict = _get_parameter_layout()
+    if not _is_in_auto_parallel_mode():
+        save_obj.init_parameters_data()
+    param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
+    if append_dict and "random_op" in append_dict:
+        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
+        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
+            random_byte = _executor._graph_executor.get_random_status(phase)
+            param_list.append({"name": "random_op", "data": random_byte})
+            append_dict.pop("random_op")
+    for (key, value) in param_dict.items():
+        each_param = {"name": key}
+        if isinstance(value, MapParameter):
+            each_param["data"] = value
+            param_list.append(each_param)
+            continue
+
+        if value.data.is_persistent_data():
+            # list save persistent_data: [Tensor, shape, type, param.key]
+            param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
+        elif value.data.offload_file_path() != "":
+            # list save offload data: [Param, shape, type, param.key]
+            param_data = ["offload_parameter"]
+            param_tensor = value.data
+            if key in parameter_layout_dict:
+                param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
+                                                      integrated_save)
+            param_data.append(param_tensor)
+            param_data.append(param_tensor.shape)
+            param_data.append(str(param_tensor.dtype))
+            param_data.append(value.key)
+        else:
+            param_data = Tensor(value.data.asnumpy())
+
+            # in automatic model parallel scenario, some parameters were split to all the devices,
+            # which should be combined before saving
+            if key in parameter_layout_dict:
+                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
+                                                    integrated_save)
+
+        each_param["data"] = param_data
+        param_list.append(each_param)
+    return param_list
+
+
+def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
+    """Convert a save_obj to param_list."""
+    if isinstance(save_obj, list):
+        return _convert_list_to_param_list(save_obj, choice_func)
+
+    if isinstance(save_obj, dict):
+        return _convert_dict_to_param_dict(save_obj, choice_func)
+
+    return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+
 def _save_param_list_data(data_list, key, param):
     """Save persistent data into save_obj."""
     dims = []
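The hunks above replace save_checkpoint's inline parameter-collection loop with the new _convert_save_obj_to_param_list dispatch, so nn.Cell, list, and dict inputs are all normalized into the same [{"name": ..., "data": ...}] records before serialization. A minimal round-trip sketch of the newly accepted list-of-dict and dict forms, based on the docstring added in this diff; the parameter name, tensor shape, and file paths below are illustrative, not taken from the package:

import numpy as np
import mindspore as ms

# List-of-dict form: each record pairs a name with a Parameter or Tensor.
params = [{"name": "fc.weight", "data": ms.Tensor(np.ones((2, 2), np.float32))}]
ms.save_checkpoint(params, "./fc.ckpt")

# load_checkpoint returns a {name: Parameter} dict, which save_checkpoint
# now accepts directly as save_obj.
param_dict = ms.load_checkpoint("./fc.ckpt")
ms.save_checkpoint(param_dict, "./fc_copy.ckpt")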
@@ -585,7 +651,7 @@ def load(file_name, **kwargs):
|
|
|
585
651
|
|
|
586
652
|
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
587
653
|
`obfuscate_model()
|
|
588
|
-
<https://www.mindspore.cn/docs/en/r2.
|
|
654
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
|
|
589
655
|
|
|
590
656
|
Returns:
|
|
591
657
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
@@ -615,7 +681,7 @@ def load(file_name, **kwargs):
|
|
|
615
681
|
|
|
616
682
|
Tutorial Examples:
|
|
617
683
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
618
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
684
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
619
685
|
"""
|
|
620
686
|
if not isinstance(file_name, str):
|
|
621
687
|
raise ValueError("For 'load', the argument 'file_name' must be string, but "
|
|
@@ -656,7 +722,7 @@ def load(file_name, **kwargs):
|
|
|
656
722
|
return graph
|
|
657
723
|
|
|
658
724
|
|
|
659
|
-
def export_split_mindir(file_name):
|
|
725
|
+
def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
|
|
660
726
|
"""
|
|
661
727
|
Auto Split MindIR.
|
|
662
728
|
|
|
@@ -664,6 +730,10 @@ def export_split_mindir(file_name):
|
|
|
664
730
|
|
|
665
731
|
Args:
|
|
666
732
|
file_name (str): MindIR file name.
|
|
733
|
+
device_num (int): device number.
|
|
734
|
+
rank_id (int): rank id.
|
|
735
|
+
dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
|
|
736
|
+
sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
|
|
667
737
|
|
|
668
738
|
Raises:
|
|
669
739
|
ValueError: MindIR file does not exist or `file_name` is not a string.
|
|
@@ -671,11 +741,9 @@ def export_split_mindir(file_name):
|
|
|
671
741
|
|
|
672
742
|
Examples:
|
|
673
743
|
>>> import mindspore as ms
|
|
674
|
-
>>> from mindspore.communication import init
|
|
675
744
|
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
676
745
|
>>>
|
|
677
|
-
>>>
|
|
678
|
-
>>> ms.export_split_mindir("net.mindir")
|
|
746
|
+
>>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
|
|
679
747
|
|
|
680
748
|
"""
|
|
681
749
|
if not isinstance(file_name, str):
|
|
@@ -690,8 +758,11 @@ def export_split_mindir(file_name):
|
|
|
690
758
|
file_name = os.path.abspath(file_name)
|
|
691
759
|
|
|
692
760
|
logger.info("Execute the process of export and split mindir.")
|
|
693
|
-
|
|
694
|
-
|
|
761
|
+
dynamic = True
|
|
762
|
+
if dynamic:
|
|
763
|
+
graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
|
|
764
|
+
else:
|
|
765
|
+
graph = split_mindir(file_name)
|
|
695
766
|
|
|
696
767
|
if graph is None:
|
|
697
768
|
if _is_cipher_file(file_name):
|
|
@@ -779,17 +850,20 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
779
850
|
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
|
780
851
|
is the same as using :func:`mindspore.export`.
|
|
781
852
|
- obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
782
|
-
should be in range of (0, 1] or in ["small", "medium", "large"].
|
|
853
|
+
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
854
|
+
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
783
855
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
784
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean
|
|
785
|
-
|
|
856
|
+
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
857
|
+
Reference to 'my_func()' in
|
|
858
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
|
|
859
|
+
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
786
860
|
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
787
861
|
when loading obfuscated model.
|
|
788
|
-
- obf_random_seed (int):
|
|
789
|
-
|
|
790
|
-
then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
791
|
-
noted that at least one of `customized_func` or `obf_random_seed` should
|
|
792
|
-
would be applied if both of them are set.
|
|
862
|
+
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
863
|
+
structure of obfuscated models corresponding to different random seeds is different. If
|
|
864
|
+
`obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
865
|
+
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
866
|
+
be set, and the latter mode would be applied if both of them are set.
|
|
793
867
|
|
|
794
868
|
kwargs (dict): Configuration options dictionary.
|
|
795
869
|
|
|
@@ -928,27 +1002,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
928
1002
|
>>> print(param_dict["conv2.weight"])
|
|
929
1003
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
930
1004
|
>>> def func(param_name):
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
1005
|
+
... whether_load = False
|
|
1006
|
+
... if param_name.startswith("conv"):
|
|
1007
|
+
... whether_load = True
|
|
1008
|
+
... if param_name.startswith("conv1"):
|
|
1009
|
+
... whether_load = False
|
|
1010
|
+
... return whether_load
|
|
937
1011
|
>>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
938
1012
|
>>> print(param_dict1["conv2.weight"])
|
|
939
1013
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
940
1014
|
>>> def func(param_name):
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
1015
|
+
... whether_load = False
|
|
1016
|
+
... if param_name.startswith("conv1"):
|
|
1017
|
+
... whether_load = True
|
|
1018
|
+
... return whether_load
|
|
945
1019
|
>>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
946
1020
|
>>> print(param_dict2)
|
|
947
1021
|
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
|
|
948
1022
|
|
|
949
1023
|
Tutorial Examples:
|
|
950
1024
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
951
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1025
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
952
1026
|
"""
|
|
953
1027
|
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
954
1028
|
specify_prefix = _check_prefix(specify_prefix)
|
|
@@ -979,8 +1053,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
979
1053
|
choice_func is not None and not choice_func(element.tag):
|
|
980
1054
|
continue
|
|
981
1055
|
if element.tensor.ByteSize() == 0:
|
|
982
|
-
_load_map_parameter(checkpoint_list, element, element_id,
|
|
983
|
-
map_data_list, map_shape_list, parameter_dict)
|
|
1056
|
+
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
|
|
984
1057
|
if element.tag in parameter_dict:
|
|
985
1058
|
map_data_list = [[], [], []]
|
|
986
1059
|
map_shape_list = [0, 0, 0]
|
|
@@ -1024,8 +1097,12 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1024
1097
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
1025
1098
|
f"'filter_prefix' or 'specify_prefix' are set correctly.")
|
|
1026
1099
|
|
|
1100
|
+
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1101
|
+
(is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
|
|
1027
1102
|
if net is not None:
|
|
1028
1103
|
load_param_into_net(net, parameter_dict, strict_load)
|
|
1104
|
+
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1105
|
+
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
1029
1106
|
|
|
1030
1107
|
return parameter_dict
|
|
1031
1108
|
|
|
@@ -1061,7 +1138,7 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
|
1061
1138
|
|
|
1062
1139
|
|
|
1063
1140
|
def _check_ckpt_file_name(ckpt_file_name):
|
|
1064
|
-
"""Check function load_checkpoint's
|
|
1141
|
+
"""Check function load_checkpoint's ckpt_file_name."""
|
|
1065
1142
|
if not isinstance(ckpt_file_name, str):
|
|
1066
1143
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
|
1067
1144
|
"but got {}.".format(type(ckpt_file_name)))
|
|
@@ -1175,7 +1252,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1175
1252
|
>>> import mindspore as ms
|
|
1176
1253
|
>>>
|
|
1177
1254
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1178
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
1255
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
1179
1256
|
>>> net = LeNet5()
|
|
1180
1257
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1181
1258
|
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
@@ -1185,7 +1262,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1185
1262
|
|
|
1186
1263
|
Tutorial Examples:
|
|
1187
1264
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1188
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1265
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1189
1266
|
"""
|
|
1190
1267
|
if not isinstance(net, nn.Cell):
|
|
1191
1268
|
logger.critical("Failed to combine the net and the parameters.")
|
|
@@ -1219,6 +1296,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
             if isinstance(param, MapParameter):
                 param.import_data(parameter_dict[param.name])
                 continue
+            # Add has attr protection when load server checkpoint file on worker.
+            if not hasattr(parameter_dict[param.name], "data"):
+                continue
             new_param = copy.deepcopy(parameter_dict[param.name])
             _update_param(param, new_param, strict_load)
             ckpt_not_load.remove(param.name)
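A small illustration of why the new `hasattr` guard is needed (plain Python; the class and names are stand-ins): checkpoint entries written for the embedding cache, such as the `[key, value, status]` list triples handled by `_warm_up_host_cache` below, are not `Parameter` objects and have no `.data` to deep-copy.

    import copy

    class ParamStub:            # stand-in for mindspore.Parameter
        def __init__(self, data):
            self.data = data

    parameter_dict = {"fc.weight": ParamStub([1.0, 2.0]),
                      "embed.weight": [b"keys", b"values", b"status"]}  # no .data

    for name, value in parameter_dict.items():
        if not hasattr(value, "data"):
            continue                  # skipped instead of raising AttributeError
        loaded = copy.deepcopy(value)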
@@ -1243,6 +1323,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     return param_not_load, ckpt_not_load


+def _warm_up_host_cache_enabled(parameter_dict):
+    """Warm up host cache enabled."""
+    if _cache_enable():
+        return True
+    for key in parameter_dict.keys():
+        if key.find(".__param_key__") != -1:
+            return True
+    return False
+
+
+def _warm_up_host_cache(parameter_dict, net):
+    """Warm up host cache."""
+    ms_role = os.getenv("MS_ROLE")
+    is_worker = ms_role == "MS_WORKER"
+    param_key_dict = {}
+    # Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
+    if is_worker:
+        net.init_parameters_data()
+        net_dict = {}
+        for name, value in net.parameters_and_names():
+            net_dict[name] = value
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+                net_value = None
+                if net_param_name not in net_dict:
+                    logger.warning("net param name : %s is not in net", net_param_name)
+                else:
+                    net_value = net_dict.get(net_param_name, None)
+                pos += len(".__param_key__")
+                param_key = int(param_name[pos:])
+                value_is_map_parameter = isinstance(value, list) and len(value) == 3
+                if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
+                    key_tensor = Tensor.from_numpy(value[0])
+                    value_tensor = Tensor.from_numpy(value[1])
+                    status_tensor = Tensor.from_numpy(value[2])
+                    _store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
+                elif not isinstance(value, list) and isinstance(net_value, Parameter):
+                    _store_warm_up_ptr_by_tensor(param_key, value)
+                else:
+                    logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value))
+    else:
+        for param_name, value in parameter_dict.items():
+            pos = param_name.find(".__param_key__")
+            if pos != -1:
+                net_param_name = param_name[:pos]
+                param_key_dict[param_name] = net_param_name
+    # Split param key from parameter_dict since worker cannot load param key.
+    warm_up_dict = {}
+    for key, value in param_key_dict.items():
+        if is_worker:
+            warm_up_dict[value] = parameter_dict.pop(key)
+        else:
+            parameter_dict[value] = parameter_dict.pop(key)
+    return (is_worker, parameter_dict, warm_up_dict)
+
+
+def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
+    """Warm up host cache post process."""
+    if is_worker:
+        net_dict.update(warm_up_dict)
+    _set_checkpoint_load_status(True)
+
+
 def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
     """When some net parameter did not load, try to continue loading."""
     prefix_name = ""
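The `.__param_key__<n>` suffix convention that all three helpers above key on, isolated as a runnable sketch (the parameter name is hypothetical):

    name = "embed.weight.__param_key__12"                 # hypothetical checkpoint entry
    pos = name.find(".__param_key__")
    net_param_name = name[:pos]                           # "embed.weight", the real net parameter
    param_key = int(name[pos + len(".__param_key__"):])   # 12, the cache-slot key

On workers, the three numpy arrays behind such an entry are wrapped with `Tensor.from_numpy` and handed to the warm-up store; on servers the suffix is simply stripped.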
@@ -1350,9 +1496,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
     Note:
         1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
         2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
-        3. Exporting functions decorated with
-        4. When exporting a function decorated with
-           calculations.
+        3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
+        4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
+           class properties in calculations.

     Args:
         net (Union[Cell, function]): MindSpore network.
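A minimal sketch of notes 3 and 4 above, exporting a `mindspore.jit`-decorated function that keeps all state in its arguments (the shapes and file name are arbitrary):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    @ms.jit
    def scaled_add(x, y):
        # No class properties are touched here, per note 4.
        return 2 * x + y

    x = Tensor(np.ones((2, 3), np.float32))
    y = Tensor(np.ones((2, 3), np.float32))
    ms.export(scaled_add, x, y, file_name="scaled_add", file_format="MINDIR")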
@@ -1388,17 +1534,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):

             - type (str): The type of obfuscation, only 'dynamic' is supported until now.
             - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
-              should be in range of (0, 1] or in ["small", "medium", "large"].
+              should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
+              correspond to 0.1, 0.3, and 0.6 respectively.
             - customized_func (function): A python function used for customized function mode, which used for control
-              the switch branch of obfuscation structure. The outputs of customized_func should be boolean
-
+              the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
+              Reference to 'my_func()' in
+              `tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
+              This function needs to ensure that its result is constant for any input. Users can refer to opaque
               predicates. If customized_func is set, then it should be passed to `load()` interface when loading
               obfuscated model.
-            - obf_random_seed (int):
-
-              then it should be passed to :class:`nn.GraphCell()` interface when loading
-              be noted that at least one of `customized_func` or `obf_random_seed` should
-              would be applied if both of them are set.
+            - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
+              structure of obfuscated models corresponding to different random seeds is different. If
+              `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
+              obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
+              be set, and the latter mode would be applied if both of them are set.

             - incremental (bool): export MindIR incrementally.

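A hedged sketch of how these options fit together, reusing the `net` and `input_tensor` from the Examples block of this docstring. In the r2.x documentation the options are grouped into an `obf_config` dict passed to `export`; that kwarg name and the seed value are assumptions taken from those docs, not from this diff:

    # At least one of 'customized_func' or 'obf_random_seed' must be set; if both
    # are set, the random-seed mode wins, and the same seed must then be passed to
    # nn.GraphCell when the obfuscated MindIR is loaded back.
    obf_config = {"type": "dynamic", "obf_ratio": 0.1, "obf_random_seed": 3423}
    ms.export(net, input_tensor, file_name="obf_lenet", file_format="MINDIR",
              obf_config=obf_config)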
@@ -1408,14 +1557,14 @@ def export(net, *inputs, file_name, file_format, **kwargs):
         >>> from mindspore import Tensor
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
         >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')

     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/r2.
+          <https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     old_ms_jit_value = context.get_context("jit_syntax_level")
     context.set_context(jit_syntax_level=mindspore.STRICT)
@@ -1475,7 +1624,7 @@ def _get_funcgraph(net, *inputs):
         >>> from mindspore import Tensor
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
         >>> ms.get_funcgraph(net, input_tensor)
@@ -1660,7 +1809,7 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
     for param_proto in model.graph.parameter:
         name = param_proto.name[param_proto.name.find(":") + 1:]
         param = net_dict[name]
-        raw_data = param.data.
+        raw_data = param.data.get_bytes()
         data_length = len(raw_data)
         append_size = 0
         if data_length % 64 != 0:
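The removed line is truncated in this view, but the replacement is `Tensor.get_bytes()`, which hands back the tensor's raw buffer directly. A quick sketch of what the call yields; for a contiguous tensor it should match the NumPy round trip:

    import numpy as np
    import mindspore as ms

    t = ms.Tensor(np.arange(4, dtype=np.float32))
    raw = t.get_bytes()                      # bytes object, 4 * 4 bytes here
    assert raw == t.asnumpy().tobytes()      # same buffer, no NumPy detour

The same substitution appears again in `_save_mindir_together` and `_save_together` below.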
@@ -1787,7 +1936,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
     for param_proto in model.graph.parameter:
         param_name = param_proto.name[param_proto.name.find(":") + 1:]
         if param_name in net_dict.keys():
-            param_data = net_dict[param_name].data.
+            param_data = net_dict[param_name].data.get_bytes()
             param_proto.raw_data = param_data
         else:
             raise ValueError("The parameter '{}' is not belongs to any cell,"
@@ -1797,10 +1946,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
         map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
         if map_param_name in net_dict.keys():
             map_parameter = net_dict[map_param_name]
-
-            map_param_proto.key_tensor.raw_data =
-            map_param_proto.value_tensor.raw_data =
-            map_param_proto.status_tensor.raw_data =
+            key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
+            map_param_proto.key_tensor.raw_data = key_bytes
+            map_param_proto.value_tensor.raw_data = value_bytes
+            map_param_proto.status_tensor.raw_data = status_bytes
         else:
             raise ValueError("The map_parameter '{}' is not belongs to any cell,"
                              "the data of parameter cannot be exported.".format(map_param_proto.name))
@@ -1831,7 +1980,7 @@ def _save_together(net_dict, model):
     for param_proto in model.graph.parameter:
         name = param_proto.name[param_proto.name.find(":") + 1:]
         if name in net_dict.keys():
-            data_total += sys.getsizeof(net_dict[name].data.
+            data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
         else:
             raise ValueError("The parameter '{}' is not belongs to any cell,"
                              "the data of parameter cannot be exported.".format(param_proto.name))
@@ -1862,7 +2011,7 @@ def _save_dataset_to_mindir(model, dataset):

 def parse_print(print_file_name):
     """
-    Parse data file generated by mindspore.ops.Print
+    Parse data file generated by :class:`mindspore.ops.Print`.

     Args:
         print_file_name (str): The file name needs to be parsed.
@@ -2039,8 +2188,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
 def restore_group_info_list(group_info_file_name):
     """
     Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
-    who saves the group_info_file_name
-    like "export GROUP_INFO_FILE=/data/group_info.pb".
+    who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FILE
+    environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".

     Args:
         group_info_file_name (str): Name of group information file.
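A usage sketch for the reworded docstring (the path is hypothetical): the file is produced during distributed training when the environment variable is exported, and parsed back afterwards with the public API.

    # Shell, before launching training:
    #   export GROUP_INFO_FILE=/data/group_info.pb
    import mindspore as ms

    rank_list = ms.restore_group_info_list("/data/group_info.pb")
    print(rank_list)   # e.g. [0]: ranks whose checkpoints match the local rank's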
@@ -2050,7 +2199,7 @@ def restore_group_info_list(group_info_file_name):

     Raises:
         ValueError: group information file is incorrect.
-        TypeError: group_info_file_name is not str.
+        TypeError: `group_info_file_name` is not str.

     Examples:
         >>> import mindspore as ms
@@ -2072,9 +2221,6 @@ def restore_group_info_list(group_info_file_name):
 def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
-    For details of it, please check:
-    `Saving and Loading Models in Hybrid Parallel Mode
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.

     Args:
         strategy_filename (str): Name of strategy file.
@@ -2096,8 +2242,6 @@ def build_searched_strategy(strategy_filename):
 def merge_sliced_parameter(sliced_parameters, strategy=None):
     """
     Merge parameter slices into one parameter. Used in the case of distributed inference.
-    For details of it, please check:
-    `<https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/save_load.html>`_.

     Args:
         sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
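A short usage sketch for `merge_sliced_parameter` (the values and parameter name are illustrative): the slices are supplied in rank order, and without a `strategy` the 1-D slices are concatenated into one Parameter.

    import numpy as np
    import mindspore as ms
    from mindspore import Parameter, Tensor

    sliced = [Parameter(Tensor(np.array([0.0, 1.0], np.float32)), "network.weight"),
              Parameter(Tensor(np.array([2.0, 3.0], np.float32)), "network.weight")]
    merged = ms.merge_sliced_parameter(sliced)   # holds [0., 1., 2., 3.]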
@@ -2191,9 +2335,6 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                                 train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
     """
     Load checkpoint into net for distributed predication. Used in the case of distributed inference.
-    For details of distributed inference, please check:
-    `Distributed Inference
-    <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/distributed_inference.html>`_ .

     Args:
         network (Cell): Network for distributed predication.
@@ -2218,6 +2359,104 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
     Raises:
         TypeError: The type of inputs do not match the requirements.
         ValueError: Failed to load checkpoint into net.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `rank table startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
+            <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
+
+        >>> import os
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> import mindspore.dataset as ds
+        >>> from mindspore import nn, ops, train
+        >>> from mindspore.communication import init
+        >>>
+        >>> step_per_epoch = 4
+        >>> device_num = 8
+        >>>
+        >>> # Define the network structure.
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, matmul_size, strategy=None):
+        ...         super().__init__()
+        ...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
+        ...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
+        ...         self.matmul = ops.MatMul()
+        ...         self.neg = ops.Neg()
+        ...         if strategy is not None:
+        ...             self.matmul.shard(strategy)
+        ...
+        ...     def construct(self, inputs):
+        ...         x = self.matmul(inputs, self.matmul_weight)
+        ...         x = self.neg(x)
+        ...         return x
+        >>>
+        >>> # Create dataset.
+        >>> def get_dataset(*inputs):
+        ...     def generate():
+        ...         for _ in range(step_per_epoch):
+        ...             yield inputs
+        ...     return generate
+        >>>
+        >>> # Train network and save distributed checkpoint.
+        >>> def train_net():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     np.random.seed(1)
+        ...     input_data = np.random.rand(16, 96).astype(np.float32)
+        ...     label_data = np.random.rand(16, 16).astype(np.float32)
+        ...     fake_dataset = get_dataset(input_data, label_data)
+        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
+        ...
+        ...     # Set parallel strategy.
+        ...     strategy = ((1, 4), (4, 1))
+        ...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
+        ...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
+        ...     network = Net(matmul_size=(96, 16), strategy=strategy)
+        ...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
+        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
+        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
+        ...     global_rank_id = int(os.getenv("RANK_ID"))
+        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
+        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
+        ...     ms.reset_auto_parallel_context()
+        >>>
+        >>> # Load distributed checkpoint and test.
+        >>> def load_model():
+        ...     ms.set_context(mode=ms.GRAPH_MODE)
+        ...     init()
+        ...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
+        ...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
+        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
+        ...     network = Net(matmul_size=(96, 16))
+        ...     model = ms.Model(network)
+        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
+        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
+        ...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
+        ...     predict_result = model.predict(predict_data)
+        ...     print(predict_result)
+        >>>
+        >>> train_net()
+        >>> load_model()
+        [[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
+        [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
+        ...
+        [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
     """
     network = Validator.check_isinstance("network", network, nn.Cell)
     _check_checkpoint_file(checkpoint_filenames)
@@ -2395,7 +2634,8 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
         return merged_param
     param_name = merged_param.name
     tensor_layout = predict_strategy[param_name]
-
+    rank = get_rank()
+    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
     requires_grad = merged_param.requires_grad
     layerwise_parallel = merged_param.layerwise_parallel
     split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)