mindspore-2.0.0rc1-cp38-none-any.whl → mindspore-2.2.0-cp38-none-any.whl
This diff shows the changes between package versions as they were publicly released to a supported registry, and is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
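Among the additions in the listing above is a new experimental optimizer package (mindspore/experimental/optim/: adam.py, adamw.py, sgd.py, optimizer.py, and a large new lr_scheduler.py). The sketch below shows the functional training-step pattern this kind of API is driven by; the PyTorch-style constructor and `StepLR` signatures are assumptions inferred from the module layout, not verified against the 2.2 release.

```python
# Minimal sketch of the experimental optimizer API added in this release.
# Assumed: AdamW and lr_scheduler.StepLR follow the PyTorch-style interface
# suggested by the new mindspore/experimental/optim/ modules listed above.
import numpy as np

import mindspore
from mindspore import Tensor, nn
from mindspore.experimental import optim

net = nn.Dense(8, 2)
loss_fn = nn.MSELoss()
optimizer = optim.AdamW(net.trainable_params(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # halve the LR every 10 steps

def forward_fn(data, target):
    return loss_fn(net(data), target)

# value_and_grad returns (loss, grads); grads are taken w.r.t. the trainable parameters.
grad_fn = mindspore.value_and_grad(forward_fn, None, net.trainable_params())

data = Tensor(np.random.randn(4, 8).astype(np.float32))
target = Tensor(np.zeros((4, 2), dtype=np.float32))
for _ in range(20):
    loss, grads = grad_fn(data, target)
    optimizer(grads)   # the experimental Optimizer is a Cell: calling it applies the update
    scheduler.step()   # advance the learning-rate schedule
```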
mindspore/train/model.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,9 +21,11 @@ from functools import wraps
 import os
 import math
 import copy
+import importlib
 import numpy as np
 
 import mindspore
+import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
 from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
@@ -37,7 +39,7 @@ from mindspore import context
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
     _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
     _reset_op_id_with_offset
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, \
+from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
     _cache_enable, _enable_distributed_mindrt
 from mindspore.train.metrics import Loss
 from mindspore import nn
@@ -49,6 +51,7 @@ from mindspore.common.api import _pynative_executor
 from mindspore.dataset.core.config import get_debug_mode
 from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
 from mindspore.train import amp
+from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
 
 
 def _transfer_tensor_to_tuple(inputs):
@@ -67,6 +70,17 @@ class _StepSync(Callback):
         _pynative_executor.sync()
 
 
+class _FrameworkProfilerCallback(Callback):
+    """
+    Profiler callback of framework for training.
+    """
+    def step_begin(self, run_context):
+        _framework_profiler_step_start()
+
+    def step_end(self, run_context):
+        _framework_profiler_step_end()
+
+
 def _save_final_ckpt(func):
     """
     Decorator function, which saves the current checkpoint when an exception occurs during training.
@@ -108,29 +122,33 @@ class Model:
     `Model` groups layers into an object with training and inference features based on the arguments.
 
     Note:
-        If use mixed precision functions, need to set parameter `optimizer` at the same time,
-        otherwise mixed precision functions do not take effect.
-        When uses mixed precision functions, `global_step` in optimizer may be different from `cur_step_num` in Model.
+        - If use mixed precision functions, need to set parameter `optimizer` at the same time,
+          otherwise mixed precision functions do not take effect.
+          When uses mixed precision functions, `global_step` in optimizer may be different from `cur_step_num`
+          in Model.
+        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
+          to perform the precision conversion again. If `Model` is used to train a converted network, `amp_level`
+          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
 
     Args:
         network (Cell): A training or testing network.
         loss_fn (Cell): Objective function. If `loss_fn` is None, the `network` should contain the calculation of loss
-            and parallel if needed. Default: None.
+            and parallel if needed. Default: ``None`` .
         optimizer (Cell): Optimizer for updating the weights. If `optimizer` is None, the `network` needs to
-            do backpropagation and update weights. Default: None.
+            do backpropagation and update weights. Default: ``None`` .
         metrics (Union[dict, set]): A Dictionary or a set of metrics for model evaluation.
-            eg: {'accuracy', 'recall'}. Default: None.
+            eg: {'accuracy', 'recall'}. Default: ``None`` .
         eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
-            `eval_network` . Default: None.
+            `eval_network` . Default: ``None`` .
         eval_indexes (list): It is used when eval_network is defined. If `eval_indexes` is None by default, all outputs
             of the `eval_network` would be passed to metrics. If `eval_indexes` is set, it must contain
             three elements: the positions of loss value, predicted value and label in outputs of the
             `eval_network`. In this case, the loss value will be passed to the `Loss` metric, the
             predicted value and label will be passed to other metrics.
             :func:`mindspore.train.Metric.set_indexes` is recommended instead of `eval_indexes`.
-            Default: None.
+            Default: ``None`` .
         amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
-            precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
+            precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
 
             - "O0": Do not change.
            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
@@ -138,7 +156,7 @@ class Model:
             Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
             - "O2": Cast network to float16, keep BatchNorm run in float32, using dynamic loss scale.
             - "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
-            - auto: Set level to recommended level in different devices. Set level to "O2" on GPU, set
+            - "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
               level to "O3" on Ascend. The recommended level is chosen by the expert experience, not applicable to all
               scenarios. User should specify the level for special network.
 
@@ -149,7 +167,7 @@ class Model:
         The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.
 
         boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
-            training. Supports ["O0", "O1", "O2"]. Default: "O0".
+            training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
 
             - "O0": Do not change.
             - "O1": Enable the boost mode, the performance is improved by about 20%, and
@@ -165,39 +183,23 @@ class Model:
             can obtain the same benefits. It is recommended to enable this function on
             the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
             boost_config_dict.
+
     Examples:
         >>> from mindspore import nn
         >>> from mindspore.train import Model
         >>>
-        >>> class Net(nn.Cell):
-        ...     def __init__(self, num_class=10, num_channel=1):
-        ...         super(Net, self).__init__()
-        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
-        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
-        ...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
-        ...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
-        ...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
-        ...         self.relu = nn.ReLU()
-        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        ...         self.flatten = nn.Flatten()
-        ...
-        ...     def construct(self, x):
-        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
-        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
-        ...         x = self.flatten(x)
-        ...         x = self.relu(self.fc1(x))
-        ...         x = self.relu(self.fc2(x))
-        ...         x = self.fc3(x)
-        ...         return x
-        >>>
-        >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
         >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
-        >>>
-        >>> # For details about how to build the dataset, please refer to the tutorial
-        >>> # document on the official website.
-        >>> dataset = create_custom_dataset()
+        >>> model.train_network
+        >>> model.predict_network
+        >>> model.eval_network
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
         >>> model.train(2, dataset)
         """
 
@@ -223,6 +225,7 @@ class Model:
         self._check_for_graph_cell(kwargs)
         self._build_boost_network(kwargs)
         self._train_network = self._build_train_network()
+        self._train_network._jit_config_dict = network.jit_config_dict
         self._build_eval_network(metrics, self._eval_network, eval_indexes)
         self._build_predict_network()
         self._current_epoch_num = 0
@@ -231,6 +234,12 @@ class Model:
         self.enable_recovery = False
         self._backbone_is_train = True
         self.need_load_ckpt = False
+        self._lite_full_predictor = None
+        self._lite_incremental_predictor = None
+        self._mindspore_lite = None
+        self._lite_infer = True  # if backend lite infer fails, set False
+        self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+
 
     def _check_for_graph_cell(self, kwargs):
         """Check for graph cell"""
@@ -458,7 +467,7 @@ class Model:
         Args:
             epoch (int): Total number of iterations on the data.
             train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
-                initialized. Default: None
+                initialized. Default: ``None``.
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
         if sink_size == -1:
@@ -485,9 +494,9 @@ class Model:
 
         Args:
             train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
-                initialized. Default: None
+                initialized. Default: ``None``.
             valid_dataset (Dataset): A evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
-                will be initialized, and `metrics` in `Model` can not be None. Default: None
+                will be initialized, and `metrics` in `Model` can not be None. Default: ``None``.
             sink_size (int): Control the amount of data in each sink. Default: -1.
             epoch (int): Total number of iterations on the data. Default: 1.
         """
@@ -562,16 +571,15 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) will
                 be returned. The data and label would be passed to the network and loss
                 function respectively.
-            callbacks (list): List of callback objects which should be executed while training. Default: None
+            callbacks (list): List of callback objects which should be executed while training. Default: ``None``.
             dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel.
-                Default: True
+                Default: ``True``.
                 Configure pynative mode or CPU, the training process will be performed with
                 dataset not sink.
             sink_size (int): Control the amount of data in each sink. Default: -1.
             initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run.
                 Default: 0.
         """
-        epoch = Validator.check_positive_int(epoch)
         if self._parameter_broadcast:
             self._train_network.set_broadcast_flag()
 
@@ -590,15 +598,14 @@ class Model:
         cb_params.train_dataset = train_dataset
         cb_params.list_callback = self._transform_callbacks(callbacks)
         valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
+        cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
         if context.get_context("mode") == context.PYNATIVE_MODE:
```
|
|
594
603
|
cb_params.list_callback.insert(0, _StepSync())
|
|
595
|
-
|
|
604
|
+
callbacks = cb_params.list_callback
|
|
596
605
|
cb_params.train_dataset_element = None
|
|
597
606
|
cb_params.network = self._network
|
|
598
|
-
if _is_role_sched():
|
|
599
|
-
epoch = 1
|
|
600
607
|
# Embedding cache server only run one step.
|
|
601
|
-
if
|
|
608
|
+
if _is_role_pserver() and _cache_enable():
|
|
602
609
|
epoch = 1
|
|
603
610
|
cb_params.last_save_ckpt_step = None
|
|
604
611
|
cb_params.latest_ckpt_file = None
|
|
@@ -632,18 +639,23 @@ class Model:
|
|
|
632
639
|
returned and passed to the network. Otherwise, a tuple (data, label) should
|
|
633
640
|
be returned. The data and label would be passed to the network and loss
|
|
634
641
|
function respectively.
|
|
635
|
-
list_callback (Callback): Executor of callback list. Default: None
|
|
636
|
-
cb_params (_InternalCallbackParam): Callback parameters. Default: None
|
|
642
|
+
list_callback (Callback): Executor of callback list. Default: ``None``.
|
|
643
|
+
cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
|
|
637
644
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
638
645
|
initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run.
|
|
639
646
|
Default: 0.
|
|
640
647
|
"""
|
|
641
648
|
is_graph = (context.get_context("mode") == context.GRAPH_MODE)
|
|
649
|
+
dataset_size = train_dataset.get_dataset_size()
|
|
650
|
+
if dataset_size % sink_size != 0:
|
|
651
|
+
logger.warning("In dataset_sink mode (dataset_size % sink_size) should equal to 0, "
|
|
652
|
+
"it is suggested to pad/drop data or adjust sink_size. "
|
|
653
|
+
"But got 'dataset_size': {}, 'sink_size': {}.".format(dataset_size, sink_size))
|
|
642
654
|
if sink_size == -1:
|
|
643
|
-
|
|
655
|
+
dataset_sink_num = epoch
|
|
644
656
|
else:
|
|
645
|
-
|
|
646
|
-
train_dataset.__total_batch__ =
|
|
657
|
+
dataset_sink_num = math.ceil(epoch * sink_size / dataset_size)
|
|
658
|
+
train_dataset.__total_batch__ = epoch * sink_size
|
|
647
659
|
|
|
648
660
|
cb_params.cur_step_num = 0
|
|
649
661
|
cb_params.dataset_sink_mode = True
|
|
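The sink-count arithmetic introduced in this hunk can be checked in isolation. A minimal sketch of the same computation (plain Python; the dataset and sink sizes are made-up illustration values):

.. code-block:: python

    import math

    dataset_size = 1875   # batches produced per epoch (hypothetical)
    sink_size = 600       # steps per sink launch (hypothetical)
    epoch = 2

    if dataset_size % sink_size != 0:
        # mirrors the new warning: pad/drop data or adjust sink_size
        print("dataset_size % sink_size != 0, sinks will straddle epoch boundaries")

    if sink_size == -1:
        dataset_sink_num = epoch  # one full-dataset sink per epoch
    else:
        # total steps to run = epoch * sink_size, spread over sinks of dataset_size steps
        dataset_sink_num = math.ceil(epoch * sink_size / dataset_size)
    print(dataset_sink_num)  # ceil(2 * 600 / 1875) == 1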
@@ -659,7 +671,7 @@ class Model:
 
         self._check_enable_recovery()
         # Used to check whether need perform recovery for process which is restarted.
-        self._check_need_load_ckpt(cb_params,
+        self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
         # Check whether this process is embedding cache server.
         is_embedding_cache_server = _is_role_pserver() and _cache_enable()
 
@@ -672,10 +684,11 @@ class Model:
                                       dataset=train_dataset,
                                       dataset_sink_mode=True,
                                       sink_size=sink_size,
-                                      epoch_num=
+                                      epoch_num=dataset_sink_num,
                                       dataset_helper=dataset_helper)
 
         cb_params.train_network = train_network
+        cb_params.dataset_helper = dataset_helper
 
         # Perform recovery for process which is restarted.
         self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
@@ -695,9 +708,6 @@ class Model:
             outputs = train_network(*inputs)
             cb_params.net_outputs = outputs
 
-            if _is_role_sched():
-                os._exit(0)
-
             # In disaster recovery scenarios, need not to execute callbacks if this step executes failed.
             need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
             if need_exec_callback_step_end:
@@ -824,7 +834,7 @@ class Model:
                 os.remove(cb_params.latest_ckpt_file)
                 raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "\
                                    + cb_params.latest_ckpt_file) from e
-            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.
+            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
             self.need_load_ckpt = False
 
     def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
@@ -853,9 +863,9 @@ class Model:
             self.epoch_iter = recovery_epoch_num
             cb_params.cur_epoch_num = self.epoch_iter + 1
             cb_params.last_save_ckpt_step = cb_params.cur_step_num
-            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.
+            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
         else:
-            _reset_training_dataset(0, dataset_helper.
+            _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())
 
         _set_recovery_context(need_reset=False)
 
@@ -871,15 +881,15 @@ class Model:
                 returned and passed to the network. Otherwise, a tuple (data, label) should
                 be returned. The data and label would be passed to the network and loss
                 function respectively.
-            list_callback (Callback): Executor of callback list. Default: None
-            cb_params (_InternalCallbackParam): Callback parameters. Default: None
+            list_callback (Callback): Executor of callback list. Default: ``None``.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
             initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run.
                 Default: 0.
         """
         dataset_helper, _ = self._exec_preprocess(is_train=True,
                                                   dataset=train_dataset,
                                                   dataset_sink_mode=False,
-                                                  epoch_num=
+                                                  epoch_num=epoch)
         cb_params.cur_step_num = 0
         cb_params.dataset_sink_mode = False
         run_context = RunContext(cb_params)
@@ -914,8 +924,6 @@ class Model:
                 self._loss_scale_manager.update_loss_scale(overflow)
 
                 list_callback.on_train_step_end(run_context)
-                if _is_role_sched():
-                    os._exit(0)
                 # Embedding cache server only run one step.
                 if is_embedding_cache_server:
                     break
@@ -959,7 +967,7 @@ class Model:
             of data will be transferred one by one. The limitation of data transmission per time is 256M.
 
             When dataset_sink_mode is True, the `step_end` method of the instance of Callback will be called at the end
-            of epoch.
+            of step in PyNative mode, or will be called at the end of epoch in Graph mode.
 
             If dataset_sink_mode is True, dataset will be bound to this model and cannot be used by other models.
 
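The reworded note above changes the documented timing of step-end callbacks under sink mode. A hedged sketch of a callback that makes the difference observable (the class name is ours; `on_train_step_end` and `original_args` are the hooks this file already calls):

.. code-block:: python

    from mindspore.train import Callback

    class StepEndProbe(Callback):
        """Illustrative only: prints when the step-end hook actually fires."""
        def on_train_step_end(self, run_context):
            cb_params = run_context.original_args()
            # With dataset_sink_mode=True this fires per step in PyNative mode
            # but only once per sunk epoch in Graph mode.
            print("step", cb_params.cur_step_num, "epoch", cb_params.cur_epoch_num)

    # usage sketch: model.train(2, dataset, callbacks=[StepEndProbe()], dataset_sink_mode=True)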
@@ -983,12 +991,12 @@ class Model:
                 passed to the `network`.
             callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
                 which should be executed while training.
-                Default: None
+                Default: ``None``.
             dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                 Configure pynative mode or CPU, the training process will be performed with
-                dataset not sink. Default: False
-            sink_size (int): Control the
-                is False.
+                dataset not sink. Default: ``False``.
+            sink_size (int): Control the number of steps for each sinking.
+                `sink_size` is invalid if `dataset_sink_mode` is False.
                 If sink_size = -1, sink the complete dataset for each epoch.
                 If sink_size > 0, sink sink_size data for each epoch.
                 Default: -1.
@@ -999,17 +1007,21 @@ class Model:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
-            >>> #
-            >>> #
-            >>> dataset =
-            >>>
-            >>>
+            >>> # Create the dataset taking MNIST as an example. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+            >>> dataset = create_dataset()
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> net = LeNet5()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> loss_scale_manager = ms.FixedLossScaleManager(1024., False)
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
            ...               loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
         """
+        # prepare dataset for obfuscated model
+        train_dataset = self._prepare_obf_dataset(train_dataset)
         device_target = context.get_context("device_target")
         if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
             logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
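The repaired `sink_size` wording maps onto calls like these (a sketch reusing `model` and `dataset` from the docstring example above):

.. code-block:: python

    model.train(2, dataset, dataset_sink_mode=True, sink_size=100)  # sink 100 steps per epoch
    model.train(2, dataset, dataset_sink_mode=True, sink_size=-1)   # sink the whole dataset per epoch
    model.train(2, dataset, dataset_sink_mode=False)                # sink_size is ignored when not sinking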
@@ -1033,7 +1045,7 @@ class Model:
         self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)
 
         Validator.check_is_int(sink_size)
-        Validator.
+        Validator.check_positive_int(epoch)
         Validator.check_non_negative_int(initial_epoch)
         if initial_epoch >= epoch:
             raise ValueError(f"For 'Model.train', the parameter 'epoch' must bigger than parameter 'initial_epoch',"
@@ -1121,42 +1133,48 @@ class Model:
                 then a tuple (data1, data2, data3, ...) with all data returned from dataset
                 will be passed to the `network`.
             valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
-                will be performed on the end of training process. Default: None.
+                will be performed on the end of training process. Default: ``None`` .
             valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies
                 how many training epochs to run before a new validation run is performed,
                 e.g. `valid_frequency=2` runs validation every 2 epochs.
                 If a list, specifies the epochs on which to run validation,
                 e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st, 5th epochs.
-                Default: 1
+                Default: ``1`` .
             callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
                 which should be executed while training.
-                Default: None.
+                Default: ``None`` .
             dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel.
                 Configure pynative mode or CPU, the training process will be performed with
-                dataset not sink. Default: False.
+                dataset not sink. Default: ``False`` .
             valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
-                Default: False.
-            sink_size (int): Control the
-                is False.
+                Default: ``False`` .
+            sink_size (int): Control the number of steps for each sinking.
+                `sink_size` is invalid if `dataset_sink_mode` is False.
                 If sink_size = -1, sink the complete dataset for each epoch.
                 If sink_size > 0, sink sink_size data for each epoch.
-                Default:
+                Default: ``-1`` .
             initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run.
-                Default: 0.
+                Default: ``0`` .
 
         Examples:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
-            >>> #
-            >>> #
-            >>> train_dataset =
-            >>> valid_dataset =
-            >>>
-            >>>
+            >>> # Create the dataset taking MNIST as an example. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+            >>> train_dataset = create_dataset("train")
+            >>> valid_dataset = create_dataset("test")
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> net = LeNet5()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
            >>> model.fit(2, train_dataset, valid_dataset)
+
+        Tutorial Examples:
+            - `Advanced Encapsulation: Model - Train and Save Model
+              <https://www.mindspore.cn/tutorials/en/r2.2/advanced/model.html#training-and-saving-model>`_
         """
         device_target = context.get_context("device_target")
         if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
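The `valid_frequency` semantics documented above accept either form; a short sketch reusing the names from the `fit` example:

.. code-block:: python

    model.fit(10, train_dataset, valid_dataset, valid_frequency=2)       # validate every 2 epochs
    model.fit(10, train_dataset, valid_dataset, valid_frequency=[1, 5])  # validate after epochs 1 and 5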
@@ -1175,7 +1193,7 @@ class Model:
                          .format(train_dataset._warmup_epoch, epoch))
 
         Validator.check_is_int(sink_size)
-        Validator.
+        Validator.check_positive_int(epoch)
         Validator.check_non_negative_int(initial_epoch)
         if initial_epoch >= epoch:
             raise ValueError(f"For 'Model.fit', the parameter 'epoch' must bigger than parameter 'initial_epoch',"
@@ -1224,21 +1242,23 @@ class Model:
 
         Args:
             train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
-                built. Default: None.
+                built. Default: ``None`` .
             valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
-                will be built, and `metrics` in `Model` can not be None. Default: None.
-            sink_size (int): Control the
-            epoch (int): Control the training epochs. Default: 1.
+                will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
+            sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
+            epoch (int): Control the training epochs. Default: ``1`` .
 
         Examples:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>> from mindspore.amp import FixedLossScaleManager
            >>>
-            >>> #
-            >>> #
-            >>> dataset =
-            >>>
+            >>> # Create the dataset taking MNIST as an example. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+            >>> dataset = create_dataset()
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -1247,6 +1267,10 @@ class Model:
            >>> model.build(dataset, epoch=2)
            >>> model.train(2, dataset)
         """
+        epoch = Validator.check_positive_int(epoch)
+        if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
+            self._train_network.check_names_and_refresh_name()
+            self._train_network._is_check_and_refresh = True
         self._init(train_dataset, valid_dataset, sink_size, epoch)
 
     def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
@@ -1255,12 +1279,12 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
-                will be performed on the end of training process. Default: None
+                will be performed on the end of training process. Default: ``None``.
             callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, which should be
-                executed while evaluation. Default: None
+                executed while evaluation. Default: ``None``.
             valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
-                Default: True
-            cb_params (_InternalCallbackParam): Callback parameters. Default: None
+                Default: ``True``.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
         """
         if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
             raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
@@ -1289,8 +1313,8 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (Callback): Executor of callback list. Default: None
-            cb_params (_InternalCallbackParam): Callback parameters. Default: None
+            list_callback (Callback): Executor of callback list. Default: ``None``.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
 
         Returns:
             Dict, which returns the loss value and metrics values for the model in the test mode.
@@ -1313,8 +1337,6 @@ class Model:
             outputs = eval_network(*inputs)
             cb_params.net_outputs = outputs
             list_callback.on_eval_step_end(run_context)
-            if _is_role_sched():
-                os._exit(0)
             self._update_metrics(outputs)
             if add_eval_loss:
                 eval_loss_fn = get_metric_fn("loss")
@@ -1337,8 +1359,8 @@ class Model:
 
         Args:
             valid_dataset (Dataset): Dataset to evaluate the model.
-            list_callback (Callback): Executor of callback list. Default: None
-            cb_params (_InternalCallbackParam): Callback parameters. Default: None
+            list_callback (Callback): Executor of callback list. Default: ``None``.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
 
         Returns:
             Dict, which returns the loss value and metrics values for the model in the test mode.
@@ -1359,8 +1381,6 @@ class Model:
             outputs = self._eval_network(*next_element)
             cb_params.net_outputs = outputs
             list_callback.on_eval_step_end(run_context)
-            if _is_role_sched():
-                os._exit(0)
             self._update_metrics(outputs)
             if add_eval_loss:
                 eval_loss_fn = get_metric_fn("loss")
@@ -1397,9 +1417,9 @@ class Model:
             valid_dataset (Dataset): Dataset to evaluate the model.
             callbacks (Optional[list(Callback), Callback]): List of callback objects or callback object,
                 which should be executed while evaluation.
-                Default: None.
+                Default: ``None`` .
             dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
-                Default: False.
+                Default: ``False`` .
 
         Returns:
             Dict, the key is the metric name defined by users and the value is the metrics value for
@@ -1409,14 +1429,21 @@ class Model:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
-            >>> #
-            >>> #
-            >>> dataset =
-            >>>
-            >>>
+            >>> # Create the dataset taking MNIST as an example. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+            >>> dataset = create_dataset()
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> net = LeNet5()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> acc = model.eval(dataset, dataset_sink_mode=False)
+
+        Tutorial Examples:
+            - `Advanced Encapsulation: Model - Train and Save Model
+              <https://www.mindspore.cn/tutorials/en/r2.2/advanced/model.html#training-and-saving-model>`_
         """
+        valid_dataset = self._prepare_obf_dataset(valid_dataset)
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
 
         _device_number_check(self._parallel_mode, self._device_number)
@@ -1464,7 +1491,140 @@ class Model:
 
         return eval_result
 
-    def
+    def _predict_lite(self, *predict_data, config=None):
+        """
+        Generate output predictions for the input samples using backend 'lite'.
+
+        Args:
+            predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
+                The predict data, can be a single tensor,
+                a list of tensor, or a tuple of tensor.
+
+            config (dict, optional) - The config parameter is enabled when the backend is ‘lite’.
+                The config includes two parts: config_path (configPath, str) and config_item (str, dict).
+                When the config_item is set, its priority is higher than the config_path. Set the ranking
+                table file for inference. The content of the configuration file is as follows:
+
+                config_path defines the path of the configuration file, which is used to pass user-defined
+                options during model building. In the following scenarios, users may need to set parameters.
+                For example: "/home/user/config.ini". Default value: ``"" `` , here is the content of the
+                config.ini file:
+
+                .. code-block::
+
+                    [ascend_context]
+                    rank_table_file = [path_a](storage initial path of the rank table file)
+                    [execution_plan]
+                    [op_name1] = data_type:float16 (operator named op_name1 is set to data type Float16)
+                    [op_name2] = data_type:float32 (operator named op_name2 is set to data type Float32)
+
+                When only the config_path is configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini"}
+
+                When only the config_dict is configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"ascend_context" : {"rank_table_file" : "path_b"},
+                              "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
+
+                When both the `config_path` and the `config_dict` are configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini",
+                              "ascend_context" : {"rank_table_file" : "path_b"},
+                              "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
+
+                Note that both the "configPath" is configured in the config_dict and the config_item,
+                in this case, the path_b in the config_dict takes precedence.
+
+        Returns:
+            Tensor, array(s) of predictions.
+        """
+        def _get_lite_context(lite_context_input):
+            # use default lite context parameters for now
+            device_target = context.get_context("device_target").lower()
+            lite_context_input.target = [device_target]
+            if device_target == 'cpu':
+                inter_op_parallel_num = context.get_context('inter_op_parallel_num')
+                if inter_op_parallel_num and isinstance(inter_op_parallel_num, int):
+                    lite_context_input.cpu.inter_op_parallel_num = inter_op_parallel_num
+            elif device_target == 'gpu':
+                device_id = context.get_context('device_id')
+                if device_id and isinstance(device_id, int):
+                    lite_context_input.gpu.device_id = device_id
+                if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
+                    from mindspore.communication import init, get_rank
+                    init()
+                    lite_context_input.gpu.rank_id = get_rank()
+            elif device_target == 'ascend':
+                device_id = context.get_context('device_id')
+                if device_id and isinstance(device_id, int):
+                    lite_context_input.ascend.device_id = device_id
+                if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
+                    from mindspore.communication import init, get_rank
+                    init()
+                    lite_context_input.ascend.rank_id = get_rank()
+                    lite_context_input.ascend.provider = "ge"
+            else:
+                raise RuntimeError(f"For predict lite, device target should be in ['gpu', 'cpu', 'ascend']"
+                                   f" but got {device_target}")
+            return lite_context_input
+
+        if not self._mindspore_lite:
+            self._mindspore_lite = importlib.import_module('mindspore_lite')
+
+        use_past = False  # default execute full model inference
+        model_group_id = None
+        if self._predict_network.get_flags().__contains__("is_first_iteration"):
+            is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
+            if isinstance(is_first_iteration, bool):
+                use_past = not is_first_iteration
+                model_group_id = self._mindspore_lite_model_group_id
+
+        check_input_data(*predict_data, data_class=Tensor)
+        if use_past:
+            # Execute incremental model inference
+            if not self._lite_incremental_predictor:
+                lite_context = _get_lite_context(self._mindspore_lite.Context())
+                self._lite_incremental_predictor = \
+                    self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
+                                                              model_group_id=model_group_id, config=config)
+
+            inputs = self._lite_incremental_predictor.get_inputs()
+            if len(predict_data) != len(inputs):
+                raise RuntimeError(f"For 'Model.predict', numbers of predict_data {len(predict_data)} "
+                                   f"is not equal to numbers of net input {len(inputs)}")
+            for i, single_data in enumerate(predict_data):
+                inputs[i].set_data_from_numpy(single_data.asnumpy())
+            outputs: list = self._lite_incremental_predictor.predict(inputs)
+        else:
+            # Execute full model inference
+            if not self._lite_full_predictor:
+                lite_context = _get_lite_context(self._mindspore_lite.Context())
+                self._lite_full_predictor = \
+                    self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
+                                                              model_group_id=model_group_id, config=config)
+
+            inputs = self._lite_full_predictor.get_inputs()
+            if len(predict_data) != len(inputs):
+                raise RuntimeError(f"For 'Model.predict', numbers of predict_data {len(predict_data)} "
+                                   f"is not equal to numbers of net input {len(inputs)}")
+            for i, single_data in enumerate(predict_data):
+                inputs[i].set_data_from_numpy(single_data.asnumpy())
+            outputs: list = self._lite_full_predictor.predict(inputs)
+        if not outputs:
+            return Tensor(outputs)
+        if len(outputs) == 1:
+            return Tensor(outputs[0].get_data_to_numpy())
+        outputs = [Tensor(single_output.get_data_to_numpy()) for single_output in outputs]
+        return tuple(outputs)
+
+    def predict(self, *predict_data, backend=None, config=None):
         """
         Generate output predictions for the input samples.
 
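The new `_predict_lite` chooses the incremental predictor only when the predict network carries an `is_first_iteration` flag set to False. A hedged sketch of driving that switch from user code (`add_flags` is the generic Cell flag mechanism; the network name and inputs are hypothetical):

.. code-block:: python

    net = LeNet5()                           # hypothetical network
    model = Model(net)
    net.add_flags(is_first_iteration=True)   # full-model path (_lite_full_predictor)
    first = model.predict(first_inputs, backend="lite")
    net.add_flags(is_first_iteration=False)  # use_past=True -> incremental path
    nxt = model.predict(next_inputs, backend="lite")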
@@ -1472,6 +1632,49 @@ class Model:
             predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
                 The predict data, can be a single tensor,
                 a list of tensor, or a tuple of tensor.
+            backend (str): Select predict backend, this parameter is an experimental feature
+                and is mainly used for MindSpore Lite cloud-side inference. Default: ``None`` .
+            config (dict, optional) - The config parameter is enabled when the backend is ‘lite’.
+                The config includes two parts: config_path (configPath, str) and config_item (str, dict).
+                When the config_item is set, its priority is higher than the config_path. Set the ranking
+                table file for inference. The content of the configuration file is as follows:
+
+                config_path defines the path of the configuration file, which is used to pass user-defined
+                options during model building. In the following scenarios, users may need to set parameters.
+                For example: "/home/user/config.ini". Default value: ``""`` , here is the content of the
+                config.ini file:
+
+                .. code-block::
+
+                    [ascend_context]
+                    rank_table_file = [path_a](storage initial path of the rank table file)
+                    [execution_plan]
+                    [op_name1] = data_type:float16 (operator named op_name1 is set to data type Float16)
+                    [op_name2] = data_type:float32 (operator named op_name2 is set to data type Float32)
+
+                When only the config_path is configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini"}
+
+                When only the config_dict is configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"ascend_context" : {"rank_table_file" : "path_b"},
+                              "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}
+
+                When both the `config_path` and the `config_dict` are configured, it is done as follows:
+
+                .. code-block::
+
+                    config = {"configPath" : "/home/user/config.ini",
+                              "ascend_context" : {"rank_table_file" : "path_b"},
+                              "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}
+
+                Note that both the "configPath" is configured in the config_dict and the config_item,
+                in this case, the path_b in the config_dict takes precedence.
 
         Returns:
             Tensor, array(s) of predictions.
@@ -1483,9 +1686,27 @@ class Model:
            >>> from mindspore.train import Model
            >>>
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
-            >>>
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> model = Model(LeNet5())
            >>> result = model.predict(input_data)
         """
+        if backend not in ['lite', None]:
+            raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
+        if backend == "lite" and self._lite_infer:
+            # pylint: disable=broad-except
+            try:
+                return self._predict_lite(*predict_data, config=config)
+            except RuntimeError:
+                self._lite_infer = False
+                logger.warning("Lite inference failed, fallback to original inference!")
+            except ImportError:
+                self._lite_infer = False
+                logger.warning("Import mindspore_lite failed, fallback to original inference!")
+            except BaseException as e:
+                self._lite_infer = False
+                logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")
+
         self._check_network_mode(self._predict_network, False)
         check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))
         _parallel_predict_check()
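Putting the new `backend` and `config` parameters together (a sketch; `LeNet5` and the config values are the placeholders used in the docstring above):

.. code-block:: python

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.train import Model

    model = Model(LeNet5())
    config = {"configPath": "/home/user/config.ini",
              "ascend_context": {"rank_table_file": "path_b"}}
    input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
    # Tries MindSpore Lite first; on RuntimeError/ImportError the except clauses
    # above clear self._lite_infer and fall back to the original inference path.
    result = model.predict(input_data, backend="lite", config=config)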
@@ -1550,12 +1771,12 @@ class Model:
                 function respectively.
             dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                 Configure pynative mode or CPU, the training process will be performed with
-                dataset not sink. Default: True.
-            sink_size (int): Control the
+                dataset not sink. Default: ``True`` .
+            sink_size (int): Control the number of steps for each sinking.
                 If sink_size = -1, sink the complete dataset for each epoch.
                 If sink_size > 0, sink sink_size data for each epoch.
                 If dataset_sink_mode is False, set sink_size as invalid.
-                Default:
+                Default: ``-1`` .
 
         Returns:
             Dict, Parameter layout dictionary used for load distributed checkpoint
@@ -1573,10 +1794,12 @@ class Model:
            >>> init()
            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
            >>>
-            >>> #
-            >>> #
-            >>> dataset =
-            >>>
+            >>> # Create the dataset taking MNIST as an example. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
+            >>> dataset = create_dataset()
+            >>> # Define the network structure of LeNet5. Refer to
+            >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = ms.FixedLossScaleManager()
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -1598,7 +1821,7 @@ class Model:
         return train_network.parameter_layout_dict
 
 
-    def infer_predict_layout(self, *predict_data):
+    def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
         """
         Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
 
@@ -1611,6 +1834,9 @@ class Model:
             predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
                 The predict data, can be a single tensor,
                 a list of tensor, or a tuple of tensor.
+            skip_backend_compile (bool): Only run the frontend compile process,
+                skip the compile process on the device side. Set this flag to True may
+                lead to recompiling process can not hit cache.
 
         Returns:
             Dict, Parameter layout dictionary used for load distributed checkpoint.
@@ -1646,7 +1872,14 @@ class Model:
         predict_net = self._predict_network
         # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
         predict_net = self._check_network_mode(predict_net, False)
-
+        if skip_backend_compile:
+            origin_phase = predict_net.phase
+            predict_net.phase = "export." + predict_net.phase
+            predict_net.compile(*predict_data)
+            # set phase back to prevent from hitting incomplete compile cache
+            predict_net.phase = origin_phase
+        else:
+            predict_net.compile(*predict_data)
         return predict_net.parameter_layout_dict
 
     def _flush_from_cache(self, cb_params):
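The new `skip_backend_compile` flag compiles under a temporary `"export."`-prefixed phase so only frontend compilation runs; a usage sketch:

.. code-block:: python

    # Frontend-only compile: a faster layout query, but as the docstring warns,
    # later full compilations may miss the compile cache.
    layout = model.infer_predict_layout(input_data, skip_backend_compile=True)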
@@ -1686,5 +1919,16 @@ class Model:
         """
         return self._eval_network
 
+    def _prepare_obf_dataset(self, dataset):
+        if not hasattr(self._network, 'obf_ratios'):
+            return dataset
+        data_size = dataset.get_dataset_size()
+        obf_ratio_dataset = []
+        for _ in range(data_size):
+            obf_ratio_dataset.append(self._network.obf_ratios)
+        obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
+        dataset = ds.zip((dataset, obf_ratio_dataset))
+        return dataset
+
 
 __all__ = ["Model"]
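For reference, the `ds.NumpySlicesDataset` + `ds.zip` pattern used by the new `_prepare_obf_dataset` in the final hunk can be reproduced standalone (toy data; `y_obf` is the column name the hunk uses, and the ratio values are made up):

.. code-block:: python

    import numpy as np
    import mindspore.dataset as ds

    base = ds.NumpySlicesDataset(data=np.ones((4, 2), np.float32), column_names=["data"])
    obf_ratios = np.array([0.5, 0.5], np.float32)  # stand-in for network.obf_ratios
    obf_col = ds.NumpySlicesDataset(data=[obf_ratios] * base.get_dataset_size(),
                                    column_names=["y_obf"])
    zipped = ds.zip((base, obf_col))  # every row now carries "data" and "y_obf"
    print(zipped.get_col_names())     # ['data', 'y_obf']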