mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -33,8 +33,7 @@ from mindspore import context
|
|
|
33
33
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
|
34
34
|
from mindspore import _checkparam as Validator
|
|
35
35
|
from mindspore.common import dtype as mstype
|
|
36
|
-
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
37
|
-
_AutoIdentifyDynamicShape
|
|
36
|
+
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
38
37
|
from mindspore.common.api import _generate_branch_control_input
|
|
39
38
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
40
39
|
from mindspore.common.tensor import Tensor
|
|
@@ -65,6 +64,15 @@ class Cell(Cell_):
|
|
|
65
64
|
graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
|
|
66
65
|
PYNATIVE_MODE (dynamic graph mode).
|
|
67
66
|
|
|
67
|
+
.. note::
|
|
68
|
+
Cell is the inference mode by default. For a class that inherits a Cell,
|
|
69
|
+
if the training and inference have different structures, the subclass performs the inference branch by default.
|
|
70
|
+
To set the training mode, refer to `mindspore.nn.Cell.set_train` .
|
|
71
|
+
|
|
72
|
+
.. warning::
|
|
73
|
+
In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
|
|
74
|
+
named 'phase' or 'cells', otherwise, an error will be raised.
|
|
75
|
+
|
|
68
76
|
Args:
|
|
69
77
|
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
|
|
70
78
|
affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
|
|
@@ -156,11 +164,9 @@ class Cell(Cell_):
|
|
|
156
164
|
self.saved_dynamic_shape = None
|
|
157
165
|
self._jit_config_dict = dict()
|
|
158
166
|
self.grad_ops_label = False
|
|
159
|
-
self.to_float_fp16 = False
|
|
160
|
-
self.ge_init = False
|
|
161
167
|
self.ge_sync_data = False
|
|
162
|
-
self.
|
|
163
|
-
|
|
168
|
+
self._is_check_and_refresh = False
|
|
169
|
+
self._amp_level = ""
|
|
164
170
|
|
|
165
171
|
def __getstate__(self):
|
|
166
172
|
base = Cell_.__getstate__(self)
|
|
@@ -192,6 +198,23 @@ class Cell(Cell_):
|
|
|
192
198
|
def param_prefix(self):
|
|
193
199
|
"""
|
|
194
200
|
Param prefix is the prefix of current cell's direct child parameter.
|
|
201
|
+
|
|
202
|
+
Examples:
|
|
203
|
+
>>> import mindspore as ms
|
|
204
|
+
>>> from mindspore import Tensor, nn
|
|
205
|
+
...
|
|
206
|
+
>>> class Net(nn.Cell):
|
|
207
|
+
... def __init__(self):
|
|
208
|
+
... super(Net, self).__init__()
|
|
209
|
+
... self.dense = nn.Dense(2, 2)
|
|
210
|
+
...
|
|
211
|
+
... def construct(self, x):
|
|
212
|
+
... x = self.dense(x)
|
|
213
|
+
... return x
|
|
214
|
+
>>> net = Net()
|
|
215
|
+
>>> net.update_cell_prefix()
|
|
216
|
+
>>> print(net.dense.param_prefix)
|
|
217
|
+
dense
|
|
195
218
|
"""
|
|
196
219
|
return self._param_prefix
|
|
197
220
|
|
|
@@ -202,7 +225,7 @@ class Cell(Cell_):
|
|
|
202
225
|
|
|
203
226
|
Tutorial Examples:
|
|
204
227
|
- `Cell and Parameter - Custom Cell Reverse
|
|
205
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
228
|
+
<https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
|
|
206
229
|
"""
|
|
207
230
|
return self._bprop_debug
|
|
208
231
|
|
|
@@ -309,6 +332,21 @@ class Cell(Cell_):
|
|
|
309
332
|
for item in self.trainable_params():
|
|
310
333
|
item.add_pipeline_stage(value)
|
|
311
334
|
|
|
335
|
+
@property
|
|
336
|
+
def pipeline_segment(self):
|
|
337
|
+
return self._pipeline_segment
|
|
338
|
+
|
|
339
|
+
@pipeline_segment.setter
|
|
340
|
+
def pipeline_segment(self, value):
|
|
341
|
+
if not isinstance(value, int) or isinstance(value, bool):
|
|
342
|
+
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
343
|
+
"must be int type, but got type : {}".format(type(value)))
|
|
344
|
+
|
|
345
|
+
if value < 0:
|
|
346
|
+
raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
347
|
+
"can not be less than 0, but got {}".format(value))
|
|
348
|
+
self._pipeline_segment = value
|
|
349
|
+
|
|
312
350
|
@property
|
|
313
351
|
def parallel_parameter_merge_net_dict(self):
|
|
314
352
|
return self._parallel_parameter_merge_net_dict
|
|
@@ -345,7 +383,7 @@ class Cell(Cell_):
|
|
|
345
383
|
if '_params_list' in self.__dict__:
|
|
346
384
|
params_list = self.__dict__['_params_list']
|
|
347
385
|
if name in params_list:
|
|
348
|
-
return
|
|
386
|
+
return params_list[name]
|
|
349
387
|
raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
|
350
388
|
|
|
351
389
|
def __del__(self):
|
|
@@ -365,11 +403,11 @@ class Cell(Cell_):
|
|
|
365
403
|
del self._params[name]
|
|
366
404
|
elif name in self._cells:
|
|
367
405
|
del self._cells[name]
|
|
406
|
+
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
407
|
+
del self._params_list[name]
|
|
408
|
+
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
409
|
+
del self._tensor_list[name]
|
|
368
410
|
else:
|
|
369
|
-
if '_params_list' in self.__dict__ and name in self._params_list:
|
|
370
|
-
del self._params_list[name]
|
|
371
|
-
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
372
|
-
del self._tensor_list[name]
|
|
373
411
|
object.__delattr__(self, name)
|
|
374
412
|
self._attr_synced = False
|
|
375
413
|
|
|
@@ -381,8 +419,8 @@ class Cell(Cell_):
|
|
|
381
419
|
res.append(self._cast_mixed_precision_inputs(item, dst_type))
|
|
382
420
|
elif isinstance(item, float):
|
|
383
421
|
res.append(self.cast(item, dst_type))
|
|
384
|
-
elif hasattr(item, "dtype") and item.dtype in
|
|
385
|
-
item.dtype != dst_type:
|
|
422
|
+
elif hasattr(item, "dtype") and item.dtype in \
|
|
423
|
+
{mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
|
|
386
424
|
res.append(self.cast(item, dst_type))
|
|
387
425
|
else:
|
|
388
426
|
res.append(item)
|
|
@@ -629,7 +667,10 @@ class Cell(Cell_):
|
|
|
629
667
|
if PackFunc.is_tracing():
|
|
630
668
|
return self._run_tracefunc(*args, **kwargs)
|
|
631
669
|
|
|
632
|
-
self.
|
|
670
|
+
if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
|
|
671
|
+
self.check_names_and_refresh_name()
|
|
672
|
+
self._is_check_and_refresh = True
|
|
673
|
+
|
|
633
674
|
# Run in Graph mode.
|
|
634
675
|
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
|
|
635
676
|
self._check_construct_args(*args)
|
|
@@ -886,14 +927,14 @@ class Cell(Cell_):
|
|
|
886
927
|
>>> import mindspore as ms
|
|
887
928
|
>>> from mindspore import nn, Tensor
|
|
888
929
|
>>>
|
|
889
|
-
>>> class
|
|
930
|
+
>>> class ReluNet(nn.Cell):
|
|
890
931
|
... def __init__(self):
|
|
891
|
-
... super(
|
|
932
|
+
... super(ReluNet, self).__init__()
|
|
892
933
|
... self.relu = nn.ReLU()
|
|
893
934
|
... def construct(self, x):
|
|
894
935
|
... return self.relu(x)
|
|
895
936
|
>>>
|
|
896
|
-
>>> net =
|
|
937
|
+
>>> net = ReluNet()
|
|
897
938
|
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
|
|
898
939
|
>>> net.set_inputs(input_dyn)
|
|
899
940
|
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
|
|
@@ -902,13 +943,10 @@ class Cell(Cell_):
|
|
|
902
943
|
if self.grad_ops_label:
|
|
903
944
|
logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
|
|
904
945
|
f'generated.')
|
|
905
|
-
for ele in inputs:
|
|
906
|
-
if isinstance(ele, str):
|
|
907
|
-
raise TypeError(f"For element in 'set_inputs', the type must not be str.")
|
|
908
946
|
self._dynamic_shape_inputs = inputs
|
|
909
947
|
self._check_construct_args(*inputs)
|
|
910
948
|
if context._get_mode() == context.PYNATIVE_MODE:
|
|
911
|
-
_pynative_executor.set_dynamic_input(self)
|
|
949
|
+
_pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
|
|
912
950
|
|
|
913
951
|
def get_inputs(self):
|
|
914
952
|
"""
|
|
@@ -919,6 +957,26 @@ class Cell(Cell_):
|
|
|
919
957
|
|
|
920
958
|
.. warning::
|
|
921
959
|
This is an experimental API that is subject to change or deletion.
|
|
960
|
+
|
|
961
|
+
Examples:
|
|
962
|
+
>>> import numpy as np
|
|
963
|
+
>>> import mindspore as ms
|
|
964
|
+
>>> from mindspore import nn, Tensor
|
|
965
|
+
>>>
|
|
966
|
+
>>> class ReluNet(nn.Cell):
|
|
967
|
+
... def __init__(self):
|
|
968
|
+
... super(ReluNet, self).__init__()
|
|
969
|
+
... self.relu = nn.ReLU()
|
|
970
|
+
... def construct(self, x):
|
|
971
|
+
... return self.relu(x)
|
|
972
|
+
>>>
|
|
973
|
+
>>> net = ReluNet()
|
|
974
|
+
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
|
|
975
|
+
>>> net.set_inputs(input_dyn)
|
|
976
|
+
>>> get_inputs = net.get_inputs()
|
|
977
|
+
>>> print(get_inputs)
|
|
978
|
+
(Tensor(shape=[3, -1], dtype=Float32, value= ),)
|
|
979
|
+
|
|
922
980
|
"""
|
|
923
981
|
|
|
924
982
|
return self._dynamic_shape_inputs
|
|
@@ -936,9 +994,8 @@ class Cell(Cell_):
|
|
|
936
994
|
self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
|
|
937
995
|
|
|
938
996
|
if self._dynamic_shape_inputs is None:
|
|
939
|
-
compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args)
|
|
940
997
|
_cell_graph_executor.compile(self, phase=self.phase,
|
|
941
|
-
jit_config_dict=self._jit_config_dict, *
|
|
998
|
+
jit_config_dict=self._jit_config_dict, *args, **kwargs)
|
|
942
999
|
else:
|
|
943
1000
|
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
944
1001
|
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
|
@@ -994,6 +1051,23 @@ class Cell(Cell_):
|
|
|
994
1051
|
Raises:
|
|
995
1052
|
KeyError: If the name of parameter is null or contains dot.
|
|
996
1053
|
TypeError: If the type of parameter is not Parameter.
|
|
1054
|
+
|
|
1055
|
+
Examples:
|
|
1056
|
+
>>> import mindspore as ms
|
|
1057
|
+
>>> from mindspore import Tensor, nn, Parameter
|
|
1058
|
+
...
|
|
1059
|
+
>>> class Net(nn.Cell):
|
|
1060
|
+
... def __init__(self):
|
|
1061
|
+
... super(Net, self).__init__()
|
|
1062
|
+
... self.relu = nn.ReLU()
|
|
1063
|
+
...
|
|
1064
|
+
... def construct(self, x):
|
|
1065
|
+
... x = self.relu(x)
|
|
1066
|
+
... return x
|
|
1067
|
+
>>> net = Net()
|
|
1068
|
+
>>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
|
|
1069
|
+
>>> print(net.bias)
|
|
1070
|
+
Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
|
|
997
1071
|
"""
|
|
998
1072
|
if not param_name:
|
|
999
1073
|
raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
|
|
@@ -1007,6 +1081,9 @@ class Cell(Cell_):
|
|
|
1007
1081
|
if not isinstance(param, Parameter) and param is not None:
|
|
1008
1082
|
raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
|
|
1009
1083
|
f"but got {type(param)}.")
|
|
1084
|
+
if param is None:
|
|
1085
|
+
raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
|
|
1086
|
+
f"but got None.")
|
|
1010
1087
|
if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
|
|
1011
1088
|
param.name = param_name
|
|
1012
1089
|
self._params[param_name] = param
|
|
@@ -1048,6 +1125,18 @@ class Cell(Cell_):
|
|
|
1048
1125
|
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
|
1049
1126
|
TypeError: If type of `child_name` is not str.
|
|
1050
1127
|
TypeError: Child Cell's type is incorrect.
|
|
1128
|
+
|
|
1129
|
+
Examples:
|
|
1130
|
+
>>> import mindspore as ms
|
|
1131
|
+
>>> from mindspore import Tensor, nn
|
|
1132
|
+
...
|
|
1133
|
+
>>> net1 = nn.ReLU()
|
|
1134
|
+
>>> net2 = nn.Dense(2, 2)
|
|
1135
|
+
>>> net1.insert_child_to_cell("child", net2)
|
|
1136
|
+
>>> print(net1)
|
|
1137
|
+
ReLU<
|
|
1138
|
+
(child): Dense<input_channels=2, output_channels=2, has_bias=True>
|
|
1139
|
+
>
|
|
1051
1140
|
"""
|
|
1052
1141
|
if not isinstance(child_name, str):
|
|
1053
1142
|
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
@@ -1118,6 +1207,25 @@ class Cell(Cell_):
|
|
|
1118
1207
|
|
|
1119
1208
|
Returns:
|
|
1120
1209
|
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
|
|
1210
|
+
|
|
1211
|
+
Examples:
|
|
1212
|
+
>>> import mindspore as ms
|
|
1213
|
+
>>> from mindspore import Tensor, nn
|
|
1214
|
+
...
|
|
1215
|
+
>>> class Net(nn.Cell):
|
|
1216
|
+
... def __init__(self):
|
|
1217
|
+
... super(Net, self).__init__()
|
|
1218
|
+
... self.dense = nn.Dense(2, 2)
|
|
1219
|
+
...
|
|
1220
|
+
... def construct(self, x):
|
|
1221
|
+
... x = self.dense(x)
|
|
1222
|
+
... return x
|
|
1223
|
+
>>> net = Net()
|
|
1224
|
+
>>> print(net.init_parameters_data())
|
|
1225
|
+
{Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
|
|
1226
|
+
Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
|
|
1227
|
+
Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
|
|
1228
|
+
Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
|
|
1121
1229
|
"""
|
|
1122
1230
|
replace = dict()
|
|
1123
1231
|
|
|
@@ -1163,6 +1271,24 @@ class Cell(Cell_):
|
|
|
1163
1271
|
|
|
1164
1272
|
Returns:
|
|
1165
1273
|
OrderedDict, return parameters dictionary.
|
|
1274
|
+
|
|
1275
|
+
Examples:
|
|
1276
|
+
>>> import mindspore as ms
|
|
1277
|
+
>>> from mindspore import Tensor, nn, Parameter
|
|
1278
|
+
...
|
|
1279
|
+
>>> class Net(nn.Cell):
|
|
1280
|
+
... def __init__(self):
|
|
1281
|
+
... super(Net, self).__init__()
|
|
1282
|
+
... self.dense = nn.Dense(2, 2)
|
|
1283
|
+
...
|
|
1284
|
+
... def construct(self, x):
|
|
1285
|
+
... x = self.dense(x)
|
|
1286
|
+
... return x
|
|
1287
|
+
>>> net = Net()
|
|
1288
|
+
>>> print(net.parameters_dict())
|
|
1289
|
+
OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
|
|
1290
|
+
requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
|
|
1291
|
+
requires_grad=True))])
|
|
1166
1292
|
"""
|
|
1167
1293
|
param_dict = OrderedDict()
|
|
1168
1294
|
for param in self.get_parameters(expand=recurse):
|
|
@@ -1238,7 +1364,7 @@ class Cell(Cell_):
|
|
|
1238
1364
|
|
|
1239
1365
|
Tutorial Examples:
|
|
1240
1366
|
- `Model Training - Optimizer
|
|
1241
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1367
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
|
|
1242
1368
|
"""
|
|
1243
1369
|
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
|
|
1244
1370
|
|
|
@@ -1263,6 +1389,7 @@ class Cell(Cell_):
|
|
|
1263
1389
|
Returns an iterator over cell parameters.
|
|
1264
1390
|
|
|
1265
1391
|
Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
|
|
1392
|
+
For more details about subcells, please see the example below.
|
|
1266
1393
|
|
|
1267
1394
|
Args:
|
|
1268
1395
|
expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
|
|
@@ -1272,11 +1399,34 @@ class Cell(Cell_):
|
|
|
1272
1399
|
Iteration, all parameters at the cell.
|
|
1273
1400
|
|
|
1274
1401
|
Examples:
|
|
1275
|
-
>>>
|
|
1276
|
-
>>>
|
|
1277
|
-
>>>
|
|
1278
|
-
>>>
|
|
1279
|
-
...
|
|
1402
|
+
>>> import mindspore as ms
|
|
1403
|
+
>>> from mindspore import nn, ops, Tensor
|
|
1404
|
+
>>> import numpy as np
|
|
1405
|
+
>>> class TestNet(nn.Cell):
|
|
1406
|
+
... def __init__(self):
|
|
1407
|
+
... super().__init__()
|
|
1408
|
+
... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
|
|
1409
|
+
... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
|
|
1410
|
+
... def construct(self, x):
|
|
1411
|
+
... x += self.my_w1
|
|
1412
|
+
... x = ops.reshape(x, (16,)) - self.my_w2
|
|
1413
|
+
... return x
|
|
1414
|
+
>>> class TestNet2(nn.Cell):
|
|
1415
|
+
... def __init__(self):
|
|
1416
|
+
... super().__init__()
|
|
1417
|
+
... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
|
|
1418
|
+
... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
|
|
1419
|
+
... # also be gathered.
|
|
1420
|
+
... self.subcell = TestNet()
|
|
1421
|
+
... def construct(self, x):
|
|
1422
|
+
... x += self.my_w1
|
|
1423
|
+
... x = ops.reshape(x, (16,)) - self.my_w2
|
|
1424
|
+
... return x
|
|
1425
|
+
>>> net = TestNet2()
|
|
1426
|
+
>>> print([p for p in net.get_parameters(expand=True)])
|
|
1427
|
+
[Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
|
|
1428
|
+
shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
|
|
1429
|
+
requires_grad=True)]
|
|
1280
1430
|
"""
|
|
1281
1431
|
for _, param in self.parameters_and_names(expand=expand):
|
|
1282
1432
|
yield param
|
|
@@ -1325,7 +1475,7 @@ class Cell(Cell_):
|
|
|
1325
1475
|
|
|
1326
1476
|
Tutorial Examples:
|
|
1327
1477
|
- `Building a Network - Model Parameters
|
|
1328
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1478
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
|
|
1329
1479
|
"""
|
|
1330
1480
|
cells = []
|
|
1331
1481
|
if expand:
|
|
@@ -1337,7 +1487,7 @@ class Cell(Cell_):
|
|
|
1337
1487
|
for cell_name, cell in cells:
|
|
1338
1488
|
params = cell._params.items()
|
|
1339
1489
|
for par_name, par in params:
|
|
1340
|
-
if par.inited_param is not None:
|
|
1490
|
+
if par is not None and par.inited_param is not None:
|
|
1341
1491
|
par = par.inited_param
|
|
1342
1492
|
if par is not None and id(par) not in params_set:
|
|
1343
1493
|
params_set.add(id(par))
|
|
@@ -1394,6 +1544,22 @@ class Cell(Cell_):
|
|
|
1394
1544
|
|
|
1395
1545
|
Returns:
|
|
1396
1546
|
Iteration, the immediate cells in the cell.
|
|
1547
|
+
|
|
1548
|
+
Examples:
|
|
1549
|
+
>>> import mindspore as ms
|
|
1550
|
+
>>> from mindspore import Tensor, nn
|
|
1551
|
+
...
|
|
1552
|
+
>>> class Net(nn.Cell):
|
|
1553
|
+
... def __init__(self):
|
|
1554
|
+
... super(Net, self).__init__()
|
|
1555
|
+
... self.dense = nn.Dense(2, 2)
|
|
1556
|
+
...
|
|
1557
|
+
... def construct(self, x):
|
|
1558
|
+
... x = self.dense(x)
|
|
1559
|
+
... return x
|
|
1560
|
+
>>> net = Net()
|
|
1561
|
+
>>> print(net.cells())
|
|
1562
|
+
odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
|
|
1397
1563
|
"""
|
|
1398
1564
|
return self.name_cells().values()
|
|
1399
1565
|
|
|
@@ -1439,6 +1605,22 @@ class Cell(Cell_):
|
|
|
1439
1605
|
|
|
1440
1606
|
Returns:
|
|
1441
1607
|
Dict, all the child cells and corresponding names in the cell.
|
|
1608
|
+
|
|
1609
|
+
Examples:
|
|
1610
|
+
>>> import mindspore as ms
|
|
1611
|
+
>>> from mindspore import Tensor, nn
|
|
1612
|
+
...
|
|
1613
|
+
>>> class Net(nn.Cell):
|
|
1614
|
+
... def __init__(self):
|
|
1615
|
+
... super(Net, self).__init__()
|
|
1616
|
+
... self.dense = nn.Dense(2, 2)
|
|
1617
|
+
...
|
|
1618
|
+
... def construct(self, x):
|
|
1619
|
+
... x = self.dense(x)
|
|
1620
|
+
... return x
|
|
1621
|
+
>>> net = Net()
|
|
1622
|
+
>>> print(net.name_cells())
|
|
1623
|
+
OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
|
|
1442
1624
|
"""
|
|
1443
1625
|
value_set = set()
|
|
1444
1626
|
cells = OrderedDict()
|
|
@@ -1454,13 +1636,8 @@ class Cell(Cell_):
|
|
|
1454
1636
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
|
|
1455
1637
|
if "fp32" in flags and flags.get("fp32", False):
|
|
1456
1638
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
"""Add mixed precision flag to each cell"""
|
|
1460
|
-
if "fp16" in flags and flags.get("fp16", False):
|
|
1461
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
|
1462
|
-
if "fp32" in flags and flags.get("fp32", False):
|
|
1463
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1639
|
+
if "bf16" in flags and flags.get("bf16", False):
|
|
1640
|
+
Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)
|
|
1464
1641
|
|
|
1465
1642
|
def apply(self, fn):
|
|
1466
1643
|
"""
|
|
@@ -1503,6 +1680,23 @@ class Cell(Cell_):
|
|
|
1503
1680
|
Args:
|
|
1504
1681
|
flags (dict): Network configuration information, currently it is used for the binding of network and
|
|
1505
1682
|
dataset. Users can also customize network attributes by this parameter.
|
|
1683
|
+
|
|
1684
|
+
Examples:
|
|
1685
|
+
>>> import mindspore as ms
|
|
1686
|
+
>>> from mindspore import Tensor, nn
|
|
1687
|
+
...
|
|
1688
|
+
>>> class Net(nn.Cell):
|
|
1689
|
+
... def __init__(self):
|
|
1690
|
+
... super(Net, self).__init__()
|
|
1691
|
+
... self.relu = nn.ReLU()
|
|
1692
|
+
...
|
|
1693
|
+
... def construct(self, x):
|
|
1694
|
+
... x = self.relu(x)
|
|
1695
|
+
... return x
|
|
1696
|
+
>>> net = Net()
|
|
1697
|
+
>>> net.add_flags(sink_mode=True)
|
|
1698
|
+
>>> print(net.sink_mode)
|
|
1699
|
+
True
|
|
1506
1700
|
"""
|
|
1507
1701
|
if not hasattr(self, "_func_graph_flags"):
|
|
1508
1702
|
self._func_graph_flags = {}
|
|
@@ -1518,9 +1712,25 @@ class Cell(Cell_):
|
|
|
1518
1712
|
Args:
|
|
1519
1713
|
flags (dict): Network configuration information, currently it is used for the binding of network and
|
|
1520
1714
|
dataset. Users can also customize network attributes by this parameter.
|
|
1715
|
+
|
|
1716
|
+
Examples:
|
|
1717
|
+
>>> import mindspore as ms
|
|
1718
|
+
>>> from mindspore import Tensor, nn
|
|
1719
|
+
...
|
|
1720
|
+
>>> class Net(nn.Cell):
|
|
1721
|
+
... def __init__(self):
|
|
1722
|
+
... super(Net, self).__init__()
|
|
1723
|
+
... self.relu = nn.ReLU()
|
|
1724
|
+
...
|
|
1725
|
+
... def construct(self, x):
|
|
1726
|
+
... x = self.relu(x)
|
|
1727
|
+
... return x
|
|
1728
|
+
>>> net = Net()
|
|
1729
|
+
>>> net.add_flags_recursive(sink_mode=True)
|
|
1730
|
+
>>> print(net.sink_mode)
|
|
1731
|
+
True
|
|
1521
1732
|
"""
|
|
1522
1733
|
self.add_flags(**flags)
|
|
1523
|
-
self._add_mixed_precision_flag_recursive(**flags)
|
|
1524
1734
|
for cell in self.cells():
|
|
1525
1735
|
cell.add_flags_recursive(**flags)
|
|
1526
1736
|
return self
|
|
@@ -1532,17 +1742,28 @@ class Cell(Cell_):
|
|
|
1532
1742
|
def get_flags(self):
|
|
1533
1743
|
"""
|
|
1534
1744
|
Get the self_defined attributes of the cell, which can be added by `add_flags` method.
|
|
1745
|
+
|
|
1746
|
+
Examples:
|
|
1747
|
+
>>> import mindspore as ms
|
|
1748
|
+
>>> from mindspore import Tensor, nn
|
|
1749
|
+
...
|
|
1750
|
+
>>> class Net(nn.Cell):
|
|
1751
|
+
... def __init__(self):
|
|
1752
|
+
... super(Net, self).__init__()
|
|
1753
|
+
... self.relu = nn.ReLU()
|
|
1754
|
+
...
|
|
1755
|
+
... def construct(self, x):
|
|
1756
|
+
... x = self.relu(x)
|
|
1757
|
+
... return x
|
|
1758
|
+
>>> net = Net()
|
|
1759
|
+
>>> net.add_flags(sink_mode=True)
|
|
1760
|
+
>>> print(net.get_flags())
|
|
1761
|
+
{'sink_mode':True}
|
|
1535
1762
|
"""
|
|
1536
1763
|
if not hasattr(self, "_func_graph_flags"):
|
|
1537
1764
|
self._func_graph_flags = {}
|
|
1538
1765
|
return self._func_graph_flags
|
|
1539
1766
|
|
|
1540
|
-
def _set_mixed_precision_type_recursive(self, mixed_type):
|
|
1541
|
-
"""Set mixed precision type to each cell"""
|
|
1542
|
-
Cell_.set_mixed_precision_type(self, mixed_type)
|
|
1543
|
-
for cell in self.cells():
|
|
1544
|
-
cell._set_mixed_precision_type_recursive(mixed_type)
|
|
1545
|
-
|
|
1546
1767
|
def to_float(self, dst_type):
|
|
1547
1768
|
"""
|
|
1548
1769
|
Add cast on all inputs of cell and child cells to run with certain float type.
|
|
@@ -1555,13 +1776,13 @@ class Cell(Cell_):
|
|
|
1555
1776
|
|
|
1556
1777
|
Args:
|
|
1557
1778
|
dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
|
|
1558
|
-
dst_type can be `mstype.float16` or `mstype.
|
|
1779
|
+
dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.
|
|
1559
1780
|
|
|
1560
1781
|
Returns:
|
|
1561
1782
|
Cell, the cell itself.
|
|
1562
1783
|
|
|
1563
1784
|
Raises:
|
|
1564
|
-
ValueError: If dst_type is not mstype.float32
|
|
1785
|
+
ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.
|
|
1565
1786
|
|
|
1566
1787
|
Supported Platforms:
|
|
1567
1788
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1573,19 +1794,15 @@ class Cell(Cell_):
|
|
|
1573
1794
|
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
|
1574
1795
|
>>> net.to_float(mstype.float16)
|
|
1575
1796
|
Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
|
|
1576
|
-
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=
|
|
1577
|
-
"""
|
|
1578
|
-
if dst_type not in (mstype.float16, mstype.float32):
|
|
1579
|
-
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32
|
|
1580
|
-
"but got type: {} and value: {}.".format(type(dst_type), dst_type))
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
self.to_float_fp16 = True
|
|
1584
|
-
else:
|
|
1585
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1586
|
-
self.to_float_fp16 = False
|
|
1587
|
-
flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
|
|
1797
|
+
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
|
|
1798
|
+
"""
|
|
1799
|
+
if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
|
|
1800
|
+
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
|
|
1801
|
+
"mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
|
|
1802
|
+
flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
|
|
1803
|
+
'bf16': dst_type == mstype.bfloat16}
|
|
1588
1804
|
self._add_init_args(**flags)
|
|
1805
|
+
self.add_flags_recursive(**flags)
|
|
1589
1806
|
return self
|
|
1590
1807
|
|
|
1591
1808
|
def set_boost(self, boost_type):
|
|
@@ -1594,7 +1811,7 @@ class Cell(Cell_):
|
|
|
1594
1811
|
accelerate the algorithm in the algorithm library.
|
|
1595
1812
|
|
|
1596
1813
|
If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
|
|
1597
|
-
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.
|
|
1814
|
+
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.
|
|
1598
1815
|
|
|
1599
1816
|
Note:
|
|
1600
1817
|
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
|
@@ -1651,12 +1868,12 @@ class Cell(Cell_):
|
|
|
1651
1868
|
|
|
1652
1869
|
Tutorial Examples:
|
|
1653
1870
|
- `Model Training - Implementing Training and Evaluation
|
|
1654
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1871
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
|
|
1655
1872
|
"""
|
|
1656
|
-
if mode
|
|
1657
|
-
self._phase = 'predict'
|
|
1658
|
-
else:
|
|
1873
|
+
if mode:
|
|
1659
1874
|
self._phase = 'train'
|
|
1875
|
+
else:
|
|
1876
|
+
self._phase = 'predict'
|
|
1660
1877
|
self.add_flags_recursive(training=mode)
|
|
1661
1878
|
return self
|
|
1662
1879
|
|
|
@@ -1685,16 +1902,27 @@ class Cell(Cell_):
|
|
|
1685
1902
|
|
|
1686
1903
|
Args:
|
|
1687
1904
|
jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
|
|
1905
|
+
|
|
1906
|
+
Examples:
|
|
1907
|
+
>>> import mindspore as ms
|
|
1908
|
+
>>> from mindspore import Tensor, nn
|
|
1909
|
+
...
|
|
1910
|
+
>>> class Net(nn.Cell):
|
|
1911
|
+
... def __init__(self):
|
|
1912
|
+
... super(Net, self).__init__()
|
|
1913
|
+
... self.relu = nn.ReLU()
|
|
1914
|
+
...
|
|
1915
|
+
... def construct(self, x):
|
|
1916
|
+
... x = self.relu(x)
|
|
1917
|
+
... return x
|
|
1918
|
+
>>> net = Net()
|
|
1919
|
+
>>> jitconfig = ms.JitConfig()
|
|
1920
|
+
>>> net.set_jit_config(jitconfig)
|
|
1688
1921
|
"""
|
|
1689
1922
|
if self._jit_config_dict:
|
|
1690
1923
|
logger.warning("For Cell, jit config can only be set once, ignore this setting.")
|
|
1691
1924
|
else:
|
|
1692
1925
|
self._jit_config_dict = jit_config.jit_config_dict
|
|
1693
|
-
enable_ge = os.getenv("MS_ENABLE_GE") == '1'
|
|
1694
|
-
enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
|
|
1695
|
-
if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
|
|
1696
|
-
raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
|
|
1697
|
-
format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
|
|
1698
1926
|
|
|
1699
1927
|
def flatten_weights(self, fusion_size=0):
|
|
1700
1928
|
"""
|
|
@@ -2290,12 +2518,13 @@ class Cell(Cell_):
|
|
|
2290
2518
|
def _run_tracefunc(self, *args, **kwargs):
|
|
2291
2519
|
""" Run Packed Cell in Pack."""
|
|
2292
2520
|
args = self._mixed_precision_cast(args)
|
|
2293
|
-
|
|
2521
|
+
need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
|
|
2522
|
+
if not PackFunc.current.is_pynative_mode and need_subgraph:
|
|
2294
2523
|
expander = PackExpander.get_instance()
|
|
2295
2524
|
args = expander.begin_subgraph(self, *args)
|
|
2296
2525
|
args = [_convert_tensor(a) for a in args]
|
|
2297
2526
|
output = self._run_construct(args, kwargs)
|
|
2298
|
-
ret = expander.end_subgraph(output)
|
|
2527
|
+
ret = expander.end_subgraph(self, output)
|
|
2299
2528
|
output = _convert_tensor(ret)
|
|
2300
2529
|
else:
|
|
2301
2530
|
with _SetMixedPrecision(self):
|
|
@@ -2306,10 +2535,23 @@ class Cell(Cell_):
|
|
|
2306
2535
|
mixed_type = self.get_mixed_precision_type()
|
|
2307
2536
|
if mixed_type == MixedPrecisionType.NOTSET:
|
|
2308
2537
|
return inputs
|
|
2309
|
-
|
|
2538
|
+
if mixed_type == MixedPrecisionType.FP16:
|
|
2539
|
+
cast_type = mstype.float16
|
|
2540
|
+
elif mixed_type == MixedPrecisionType.BF16:
|
|
2541
|
+
cast_type = mstype.bfloat16
|
|
2542
|
+
else:
|
|
2543
|
+
cast_type = mstype.float32
|
|
2310
2544
|
cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
|
|
2311
2545
|
return cast_inputs
|
|
2312
2546
|
|
|
2547
|
+
def _get_attr_from_cell(self, network):
|
|
2548
|
+
if not isinstance(network, Cell):
|
|
2549
|
+
return
|
|
2550
|
+
if hasattr(network, "jit_config_dict"):
|
|
2551
|
+
self._jit_config_dict = network.jit_config_dict
|
|
2552
|
+
if hasattr(network, "_amp_level"):
|
|
2553
|
+
self._amp_level = getattr(network, "_amp_level")
|
|
2554
|
+
|
|
2313
2555
|
|
|
2314
2556
|
class GraphCell(Cell):
|
|
2315
2557
|
"""
|