mindspore-2.1.0-cp39-cp39-win_amd64.whl → mindspore-2.2.10-cp39-cp39-win_amd64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic. More details are linked from the original diff page.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -33,8 +33,7 @@ from mindspore import context
|
|
|
33
33
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
|
34
34
|
from mindspore import _checkparam as Validator
|
|
35
35
|
from mindspore.common import dtype as mstype
|
|
36
|
-
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
|
|
37
|
-
_AutoIdentifyDynamicShape
|
|
36
|
+
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
|
|
38
37
|
from mindspore.common.api import _generate_branch_control_input
|
|
39
38
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
40
39
|
from mindspore.common.tensor import Tensor
|
|
@@ -65,6 +64,15 @@ class Cell(Cell_):
|
|
|
65
64
|
graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
|
|
66
65
|
PYNATIVE_MODE (dynamic graph mode).
|
|
67
66
|
|
|
67
|
+
.. note::
|
|
68
|
+
Cell is the inference mode by default. For a class that inherits a Cell,
|
|
69
|
+
if the training and inference have different structures, the subclass performs the inference branch by default.
|
|
70
|
+
To set the training mode, refer to `mindspore.nn.Cell.set_train` .
|
|
71
|
+
|
|
72
|
+
.. warning::
|
|
73
|
+
In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
|
|
74
|
+
named 'phase' or 'cells', otherwise, an error will be raised.
|
|
75
|
+
|
|
68
76
|
Args:
|
|
69
77
|
auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
|
|
70
78
|
affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
|
|
@@ -156,11 +164,9 @@ class Cell(Cell_):
|
|
|
156
164
|
self.saved_dynamic_shape = None
|
|
157
165
|
self._jit_config_dict = dict()
|
|
158
166
|
self.grad_ops_label = False
|
|
159
|
-
self.to_float_fp16 = False
|
|
160
|
-
self.ge_init = False
|
|
161
167
|
self.ge_sync_data = False
|
|
162
|
-
self.
|
|
163
|
-
|
|
168
|
+
self._is_check_and_refresh = False
|
|
169
|
+
self._amp_level = ""
|
|
164
170
|
|
|
165
171
|
def __getstate__(self):
|
|
166
172
|
base = Cell_.__getstate__(self)
|
|
@@ -192,6 +198,23 @@ class Cell(Cell_):
|
|
|
192
198
|
def param_prefix(self):
|
|
193
199
|
"""
|
|
194
200
|
Param prefix is the prefix of current cell's direct child parameter.
|
|
201
|
+
|
|
202
|
+
Examples:
|
|
203
|
+
>>> import mindspore as ms
|
|
204
|
+
>>> from mindspore import Tensor, nn
|
|
205
|
+
...
|
|
206
|
+
>>> class Net(nn.Cell):
|
|
207
|
+
... def __init__(self):
|
|
208
|
+
... super(Net, self).__init__()
|
|
209
|
+
... self.dense = nn.Dense(2, 2)
|
|
210
|
+
...
|
|
211
|
+
... def construct(self, x):
|
|
212
|
+
... x = self.dense(x)
|
|
213
|
+
... return x
|
|
214
|
+
>>> net = Net()
|
|
215
|
+
>>> net.update_cell_prefix()
|
|
216
|
+
>>> print(net.dense.param_prefix)
|
|
217
|
+
dense
|
|
195
218
|
"""
|
|
196
219
|
return self._param_prefix
|
|
197
220
|
|
|
@@ -202,7 +225,7 @@ class Cell(Cell_):
|
|
|
202
225
|
|
|
203
226
|
Tutorial Examples:
|
|
204
227
|
- `Cell and Parameter - Custom Cell Reverse
|
|
205
|
-
<https://mindspore.cn/tutorials/en/r2.1/advanced/modules/layer.html#custom-cell-reverse>`_
|
|
228
|
+
<https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
|
|
206
229
|
"""
|
|
207
230
|
return self._bprop_debug
|
|
208
231
|
|
|
@@ -309,6 +332,21 @@ class Cell(Cell_):
|
|
|
309
332
|
for item in self.trainable_params():
|
|
310
333
|
item.add_pipeline_stage(value)
|
|
311
334
|
|
|
335
|
+
@property
|
|
336
|
+
def pipeline_segment(self):
|
|
337
|
+
return self._pipeline_segment
|
|
338
|
+
|
|
339
|
+
@pipeline_segment.setter
|
|
340
|
+
def pipeline_segment(self, value):
|
|
341
|
+
if not isinstance(value, int) or isinstance(value, bool):
|
|
342
|
+
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
343
|
+
"must be int type, but got type : {}".format(type(value)))
|
|
344
|
+
|
|
345
|
+
if value < 0:
|
|
346
|
+
raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
347
|
+
"can not be less than 0, but got {}".format(value))
|
|
348
|
+
self._pipeline_segment = value
|
|
349
|
+
|
|
312
350
|
@property
|
|
313
351
|
def parallel_parameter_merge_net_dict(self):
|
|
314
352
|
return self._parallel_parameter_merge_net_dict
|
|
@@ -345,7 +383,7 @@ class Cell(Cell_):
|
|
|
345
383
|
if '_params_list' in self.__dict__:
|
|
346
384
|
params_list = self.__dict__['_params_list']
|
|
347
385
|
if name in params_list:
|
|
348
|
-
return
|
|
386
|
+
return params_list[name]
|
|
349
387
|
raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
|
|
350
388
|
|
|
351
389
|
def __del__(self):
|
|
@@ -365,11 +403,11 @@ class Cell(Cell_):
|
|
|
365
403
|
del self._params[name]
|
|
366
404
|
elif name in self._cells:
|
|
367
405
|
del self._cells[name]
|
|
406
|
+
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
407
|
+
del self._params_list[name]
|
|
408
|
+
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
409
|
+
del self._tensor_list[name]
|
|
368
410
|
else:
|
|
369
|
-
if '_params_list' in self.__dict__ and name in self._params_list:
|
|
370
|
-
del self._params_list[name]
|
|
371
|
-
elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
|
|
372
|
-
del self._tensor_list[name]
|
|
373
411
|
object.__delattr__(self, name)
|
|
374
412
|
self._attr_synced = False
|
|
375
413
|
|
|
@@ -381,8 +419,8 @@ class Cell(Cell_):
|
|
|
381
419
|
res.append(self._cast_mixed_precision_inputs(item, dst_type))
|
|
382
420
|
elif isinstance(item, float):
|
|
383
421
|
res.append(self.cast(item, dst_type))
|
|
384
|
-
elif hasattr(item, "dtype") and item.dtype in
|
|
385
|
-
item.dtype != dst_type:
|
|
422
|
+
elif hasattr(item, "dtype") and item.dtype in \
|
|
423
|
+
{mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
|
|
386
424
|
res.append(self.cast(item, dst_type))
|
|
387
425
|
else:
|
|
388
426
|
res.append(item)
|
|
@@ -629,7 +667,10 @@ class Cell(Cell_):
|
|
|
629
667
|
if PackFunc.is_tracing():
|
|
630
668
|
return self._run_tracefunc(*args, **kwargs)
|
|
631
669
|
|
|
632
|
-
self.
|
|
670
|
+
if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
|
|
671
|
+
self.check_names_and_refresh_name()
|
|
672
|
+
self._is_check_and_refresh = True
|
|
673
|
+
|
|
633
674
|
# Run in Graph mode.
|
|
634
675
|
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
|
|
635
676
|
self._check_construct_args(*args)
|
|
@@ -886,14 +927,14 @@ class Cell(Cell_):
|
|
|
886
927
|
>>> import mindspore as ms
|
|
887
928
|
>>> from mindspore import nn, Tensor
|
|
888
929
|
>>>
|
|
889
|
-
>>> class
|
|
930
|
+
>>> class ReluNet(nn.Cell):
|
|
890
931
|
... def __init__(self):
|
|
891
|
-
... super(
|
|
932
|
+
... super(ReluNet, self).__init__()
|
|
892
933
|
... self.relu = nn.ReLU()
|
|
893
934
|
... def construct(self, x):
|
|
894
935
|
... return self.relu(x)
|
|
895
936
|
>>>
|
|
896
|
-
>>> net =
|
|
937
|
+
>>> net = ReluNet()
|
|
897
938
|
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
|
|
898
939
|
>>> net.set_inputs(input_dyn)
|
|
899
940
|
>>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
|
|
@@ -902,13 +943,10 @@ class Cell(Cell_):
|
|
|
902
943
|
if self.grad_ops_label:
|
|
903
944
|
logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
|
|
904
945
|
f'generated.')
|
|
905
|
-
for ele in inputs:
|
|
906
|
-
if isinstance(ele, str):
|
|
907
|
-
raise TypeError(f"For element in 'set_inputs', the type must not be str.")
|
|
908
946
|
self._dynamic_shape_inputs = inputs
|
|
909
947
|
self._check_construct_args(*inputs)
|
|
910
948
|
if context._get_mode() == context.PYNATIVE_MODE:
|
|
911
|
-
_pynative_executor.set_dynamic_input(self)
|
|
949
|
+
_pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
|
|
912
950
|
|
|
913
951
|
def get_inputs(self):
|
|
914
952
|
"""
|
|
@@ -919,6 +957,26 @@ class Cell(Cell_):
|
|
|
919
957
|
|
|
920
958
|
.. warning::
|
|
921
959
|
This is an experimental API that is subject to change or deletion.
|
|
960
|
+
|
|
961
|
+
Examples:
|
|
962
|
+
>>> import numpy as np
|
|
963
|
+
>>> import mindspore as ms
|
|
964
|
+
>>> from mindspore import nn, Tensor
|
|
965
|
+
>>>
|
|
966
|
+
>>> class ReluNet(nn.Cell):
|
|
967
|
+
... def __init__(self):
|
|
968
|
+
... super(ReluNet, self).__init__()
|
|
969
|
+
... self.relu = nn.ReLU()
|
|
970
|
+
... def construct(self, x):
|
|
971
|
+
... return self.relu(x)
|
|
972
|
+
>>>
|
|
973
|
+
>>> net = ReluNet()
|
|
974
|
+
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
|
|
975
|
+
>>> net.set_inputs(input_dyn)
|
|
976
|
+
>>> get_inputs = net.get_inputs()
|
|
977
|
+
>>> print(get_inputs)
|
|
978
|
+
(Tensor(shape=[3, -1], dtype=Float32, value= ),)
|
|
979
|
+
|
|
922
980
|
"""
|
|
923
981
|
|
|
924
982
|
return self._dynamic_shape_inputs
|
|
@@ -936,9 +994,8 @@ class Cell(Cell_):
|
|
|
936
994
|
self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
|
|
937
995
|
|
|
938
996
|
if self._dynamic_shape_inputs is None:
|
|
939
|
-
compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args)
|
|
940
997
|
_cell_graph_executor.compile(self, phase=self.phase,
|
|
941
|
-
jit_config_dict=self._jit_config_dict, *
|
|
998
|
+
jit_config_dict=self._jit_config_dict, *args, **kwargs)
|
|
942
999
|
else:
|
|
943
1000
|
self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
|
|
944
1001
|
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
|
@@ -994,6 +1051,23 @@ class Cell(Cell_):
|
|
|
994
1051
|
Raises:
|
|
995
1052
|
KeyError: If the name of parameter is null or contains dot.
|
|
996
1053
|
TypeError: If the type of parameter is not Parameter.
|
|
1054
|
+
|
|
1055
|
+
Examples:
|
|
1056
|
+
>>> import mindspore as ms
|
|
1057
|
+
>>> from mindspore import Tensor, nn, Parameter
|
|
1058
|
+
...
|
|
1059
|
+
>>> class Net(nn.Cell):
|
|
1060
|
+
... def __init__(self):
|
|
1061
|
+
... super(Net, self).__init__()
|
|
1062
|
+
... self.relu = nn.ReLU()
|
|
1063
|
+
...
|
|
1064
|
+
... def construct(self, x):
|
|
1065
|
+
... x = self.relu(x)
|
|
1066
|
+
... return x
|
|
1067
|
+
>>> net = Net()
|
|
1068
|
+
>>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
|
|
1069
|
+
>>> print(net.bias)
|
|
1070
|
+
Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
|
|
997
1071
|
"""
|
|
998
1072
|
if not param_name:
|
|
999
1073
|
raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
|
|
@@ -1048,6 +1122,18 @@ class Cell(Cell_):
|
|
|
1048
1122
|
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
|
1049
1123
|
TypeError: If type of `child_name` is not str.
|
|
1050
1124
|
TypeError: Child Cell's type is incorrect.
|
|
1125
|
+
|
|
1126
|
+
Examples:
|
|
1127
|
+
>>> import mindspore as ms
|
|
1128
|
+
>>> from mindspore import Tensor, nn
|
|
1129
|
+
...
|
|
1130
|
+
>>> net1 = nn.ReLU()
|
|
1131
|
+
>>> net2 = nn.Dense(2, 2)
|
|
1132
|
+
>>> net1.insert_child_to_cell("child", net2)
|
|
1133
|
+
>>> print(net1)
|
|
1134
|
+
ReLU<
|
|
1135
|
+
(child): Dense<input_channels=2, output_channels=2, has_bias=True>
|
|
1136
|
+
>
|
|
1051
1137
|
"""
|
|
1052
1138
|
if not isinstance(child_name, str):
|
|
1053
1139
|
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
@@ -1118,6 +1204,25 @@ class Cell(Cell_):
|
|
|
1118
1204
|
|
|
1119
1205
|
Returns:
|
|
1120
1206
|
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
|
|
1207
|
+
|
|
1208
|
+
Examples:
|
|
1209
|
+
>>> import mindspore as ms
|
|
1210
|
+
>>> from mindspore import Tensor, nn
|
|
1211
|
+
...
|
|
1212
|
+
>>> class Net(nn.Cell):
|
|
1213
|
+
... def __init__(self):
|
|
1214
|
+
... super(Net, self).__init__()
|
|
1215
|
+
... self.dense = nn.Dense(2, 2)
|
|
1216
|
+
...
|
|
1217
|
+
... def construct(self, x):
|
|
1218
|
+
... x = self.dense(x)
|
|
1219
|
+
... return x
|
|
1220
|
+
>>> net = Net()
|
|
1221
|
+
>>> print(net.init_parameters_data())
|
|
1222
|
+
{Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
|
|
1223
|
+
Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
|
|
1224
|
+
Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
|
|
1225
|
+
Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
|
|
1121
1226
|
"""
|
|
1122
1227
|
replace = dict()
|
|
1123
1228
|
|
|
@@ -1163,6 +1268,24 @@ class Cell(Cell_):
|
|
|
1163
1268
|
|
|
1164
1269
|
Returns:
|
|
1165
1270
|
OrderedDict, return parameters dictionary.
|
|
1271
|
+
|
|
1272
|
+
Examples:
|
|
1273
|
+
>>> import mindspore as ms
|
|
1274
|
+
>>> from mindspore import Tensor, nn, Parameter
|
|
1275
|
+
...
|
|
1276
|
+
>>> class Net(nn.Cell):
|
|
1277
|
+
... def __init__(self):
|
|
1278
|
+
... super(Net, self).__init__()
|
|
1279
|
+
... self.dense = nn.Dense(2, 2)
|
|
1280
|
+
...
|
|
1281
|
+
... def construct(self, x):
|
|
1282
|
+
... x = self.dense(x)
|
|
1283
|
+
... return x
|
|
1284
|
+
>>> net = Net()
|
|
1285
|
+
>>> print(net.parameters_dict())
|
|
1286
|
+
OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
|
|
1287
|
+
requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
|
|
1288
|
+
requires_grad=True))])
|
|
1166
1289
|
"""
|
|
1167
1290
|
param_dict = OrderedDict()
|
|
1168
1291
|
for param in self.get_parameters(expand=recurse):
|
|
@@ -1238,7 +1361,7 @@ class Cell(Cell_):
|
|
|
1238
1361
|
|
|
1239
1362
|
Tutorial Examples:
|
|
1240
1363
|
- `Model Training - Optimizer
|
|
1241
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1364
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
|
|
1242
1365
|
"""
|
|
1243
1366
|
return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
|
|
1244
1367
|
|
|
@@ -1263,6 +1386,7 @@ class Cell(Cell_):
|
|
|
1263
1386
|
Returns an iterator over cell parameters.
|
|
1264
1387
|
|
|
1265
1388
|
Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
|
|
1389
|
+
For more details about subcells, please see the example below.
|
|
1266
1390
|
|
|
1267
1391
|
Args:
|
|
1268
1392
|
expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
|
|
@@ -1272,11 +1396,34 @@ class Cell(Cell_):
|
|
|
1272
1396
|
Iteration, all parameters at the cell.
|
|
1273
1397
|
|
|
1274
1398
|
Examples:
|
|
1275
|
-
>>>
|
|
1276
|
-
>>>
|
|
1277
|
-
>>>
|
|
1278
|
-
>>>
|
|
1279
|
-
...
|
|
1399
|
+
>>> import mindspore as ms
|
|
1400
|
+
>>> from mindspore import nn, ops, Tensor
|
|
1401
|
+
>>> import numpy as np
|
|
1402
|
+
>>> class TestNet(nn.Cell):
|
|
1403
|
+
... def __init__(self):
|
|
1404
|
+
... super().__init__()
|
|
1405
|
+
... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
|
|
1406
|
+
... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
|
|
1407
|
+
... def construct(self, x):
|
|
1408
|
+
... x += self.my_w1
|
|
1409
|
+
... x = ops.reshape(x, (16,)) - self.my_w2
|
|
1410
|
+
... return x
|
|
1411
|
+
>>> class TestNet2(nn.Cell):
|
|
1412
|
+
... def __init__(self):
|
|
1413
|
+
... super().__init__()
|
|
1414
|
+
... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
|
|
1415
|
+
... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
|
|
1416
|
+
... # also be gathered.
|
|
1417
|
+
... self.subcell = TestNet()
|
|
1418
|
+
... def construct(self, x):
|
|
1419
|
+
... x += self.my_w1
|
|
1420
|
+
... x = ops.reshape(x, (16,)) - self.my_w2
|
|
1421
|
+
... return x
|
|
1422
|
+
>>> net = TestNet2()
|
|
1423
|
+
>>> print([p for p in net.get_parameters(expand=True)])
|
|
1424
|
+
[Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
|
|
1425
|
+
shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
|
|
1426
|
+
requires_grad=True)]
|
|
1280
1427
|
"""
|
|
1281
1428
|
for _, param in self.parameters_and_names(expand=expand):
|
|
1282
1429
|
yield param
|
|
@@ -1325,7 +1472,7 @@ class Cell(Cell_):
|
|
|
1325
1472
|
|
|
1326
1473
|
Tutorial Examples:
|
|
1327
1474
|
- `Building a Network - Model Parameters
|
|
1328
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1475
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
|
|
1329
1476
|
"""
|
|
1330
1477
|
cells = []
|
|
1331
1478
|
if expand:
|
|
@@ -1337,7 +1484,7 @@ class Cell(Cell_):
|
|
|
1337
1484
|
for cell_name, cell in cells:
|
|
1338
1485
|
params = cell._params.items()
|
|
1339
1486
|
for par_name, par in params:
|
|
1340
|
-
if par.inited_param is not None:
|
|
1487
|
+
if par is not None and par.inited_param is not None:
|
|
1341
1488
|
par = par.inited_param
|
|
1342
1489
|
if par is not None and id(par) not in params_set:
|
|
1343
1490
|
params_set.add(id(par))
|
|
@@ -1394,6 +1541,22 @@ class Cell(Cell_):
|
|
|
1394
1541
|
|
|
1395
1542
|
Returns:
|
|
1396
1543
|
Iteration, the immediate cells in the cell.
|
|
1544
|
+
|
|
1545
|
+
Examples:
|
|
1546
|
+
>>> import mindspore as ms
|
|
1547
|
+
>>> from mindspore import Tensor, nn
|
|
1548
|
+
...
|
|
1549
|
+
>>> class Net(nn.Cell):
|
|
1550
|
+
... def __init__(self):
|
|
1551
|
+
... super(Net, self).__init__()
|
|
1552
|
+
... self.dense = nn.Dense(2, 2)
|
|
1553
|
+
...
|
|
1554
|
+
... def construct(self, x):
|
|
1555
|
+
... x = self.dense(x)
|
|
1556
|
+
... return x
|
|
1557
|
+
>>> net = Net()
|
|
1558
|
+
>>> print(net.cells())
|
|
1559
|
+
odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
|
|
1397
1560
|
"""
|
|
1398
1561
|
return self.name_cells().values()
|
|
1399
1562
|
|
|
@@ -1439,6 +1602,22 @@ class Cell(Cell_):
|
|
|
1439
1602
|
|
|
1440
1603
|
Returns:
|
|
1441
1604
|
Dict, all the child cells and corresponding names in the cell.
|
|
1605
|
+
|
|
1606
|
+
Examples:
|
|
1607
|
+
>>> import mindspore as ms
|
|
1608
|
+
>>> from mindspore import Tensor, nn
|
|
1609
|
+
...
|
|
1610
|
+
>>> class Net(nn.Cell):
|
|
1611
|
+
... def __init__(self):
|
|
1612
|
+
... super(Net, self).__init__()
|
|
1613
|
+
... self.dense = nn.Dense(2, 2)
|
|
1614
|
+
...
|
|
1615
|
+
... def construct(self, x):
|
|
1616
|
+
... x = self.dense(x)
|
|
1617
|
+
... return x
|
|
1618
|
+
>>> net = Net()
|
|
1619
|
+
>>> print(net.name_cells())
|
|
1620
|
+
OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
|
|
1442
1621
|
"""
|
|
1443
1622
|
value_set = set()
|
|
1444
1623
|
cells = OrderedDict()
|
|
@@ -1454,13 +1633,8 @@ class Cell(Cell_):
|
|
|
1454
1633
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
|
|
1455
1634
|
if "fp32" in flags and flags.get("fp32", False):
|
|
1456
1635
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
"""Add mixed precision flag to each cell"""
|
|
1460
|
-
if "fp16" in flags and flags.get("fp16", False):
|
|
1461
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
|
1462
|
-
if "fp32" in flags and flags.get("fp32", False):
|
|
1463
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1636
|
+
if "bf16" in flags and flags.get("bf16", False):
|
|
1637
|
+
Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)
|
|
1464
1638
|
|
|
1465
1639
|
def apply(self, fn):
|
|
1466
1640
|
"""
|
|
@@ -1503,6 +1677,23 @@ class Cell(Cell_):
|
|
|
1503
1677
|
Args:
|
|
1504
1678
|
flags (dict): Network configuration information, currently it is used for the binding of network and
|
|
1505
1679
|
dataset. Users can also customize network attributes by this parameter.
|
|
1680
|
+
|
|
1681
|
+
Examples:
|
|
1682
|
+
>>> import mindspore as ms
|
|
1683
|
+
>>> from mindspore import Tensor, nn
|
|
1684
|
+
...
|
|
1685
|
+
>>> class Net(nn.Cell):
|
|
1686
|
+
... def __init__(self):
|
|
1687
|
+
... super(Net, self).__init__()
|
|
1688
|
+
... self.relu = nn.ReLU()
|
|
1689
|
+
...
|
|
1690
|
+
... def construct(self, x):
|
|
1691
|
+
... x = self.relu(x)
|
|
1692
|
+
... return x
|
|
1693
|
+
>>> net = Net()
|
|
1694
|
+
>>> net.add_flags(sink_mode=True)
|
|
1695
|
+
>>> print(net.sink_mode)
|
|
1696
|
+
True
|
|
1506
1697
|
"""
|
|
1507
1698
|
if not hasattr(self, "_func_graph_flags"):
|
|
1508
1699
|
self._func_graph_flags = {}
|
|
@@ -1518,9 +1709,25 @@ class Cell(Cell_):
|
|
|
1518
1709
|
Args:
|
|
1519
1710
|
flags (dict): Network configuration information, currently it is used for the binding of network and
|
|
1520
1711
|
dataset. Users can also customize network attributes by this parameter.
|
|
1712
|
+
|
|
1713
|
+
Examples:
|
|
1714
|
+
>>> import mindspore as ms
|
|
1715
|
+
>>> from mindspore import Tensor, nn
|
|
1716
|
+
...
|
|
1717
|
+
>>> class Net(nn.Cell):
|
|
1718
|
+
... def __init__(self):
|
|
1719
|
+
... super(Net, self).__init__()
|
|
1720
|
+
... self.relu = nn.ReLU()
|
|
1721
|
+
...
|
|
1722
|
+
... def construct(self, x):
|
|
1723
|
+
... x = self.relu(x)
|
|
1724
|
+
... return x
|
|
1725
|
+
>>> net = Net()
|
|
1726
|
+
>>> net.add_flags_recursive(sink_mode=True)
|
|
1727
|
+
>>> print(net.sink_mode)
|
|
1728
|
+
True
|
|
1521
1729
|
"""
|
|
1522
1730
|
self.add_flags(**flags)
|
|
1523
|
-
self._add_mixed_precision_flag_recursive(**flags)
|
|
1524
1731
|
for cell in self.cells():
|
|
1525
1732
|
cell.add_flags_recursive(**flags)
|
|
1526
1733
|
return self
|
|
@@ -1532,17 +1739,28 @@ class Cell(Cell_):
|
|
|
1532
1739
|
def get_flags(self):
|
|
1533
1740
|
"""
|
|
1534
1741
|
Get the self_defined attributes of the cell, which can be added by `add_flags` method.
|
|
1742
|
+
|
|
1743
|
+
Examples:
|
|
1744
|
+
>>> import mindspore as ms
|
|
1745
|
+
>>> from mindspore import Tensor, nn
|
|
1746
|
+
...
|
|
1747
|
+
>>> class Net(nn.Cell):
|
|
1748
|
+
... def __init__(self):
|
|
1749
|
+
... super(Net, self).__init__()
|
|
1750
|
+
... self.relu = nn.ReLU()
|
|
1751
|
+
...
|
|
1752
|
+
... def construct(self, x):
|
|
1753
|
+
... x = self.relu(x)
|
|
1754
|
+
... return x
|
|
1755
|
+
>>> net = Net()
|
|
1756
|
+
>>> net.add_flags(sink_mode=True)
|
|
1757
|
+
>>> print(net.get_flags())
|
|
1758
|
+
{'sink_mode':True}
|
|
1535
1759
|
"""
|
|
1536
1760
|
if not hasattr(self, "_func_graph_flags"):
|
|
1537
1761
|
self._func_graph_flags = {}
|
|
1538
1762
|
return self._func_graph_flags
|
|
1539
1763
|
|
|
1540
|
-
def _set_mixed_precision_type_recursive(self, mixed_type):
|
|
1541
|
-
"""Set mixed precision type to each cell"""
|
|
1542
|
-
Cell_.set_mixed_precision_type(self, mixed_type)
|
|
1543
|
-
for cell in self.cells():
|
|
1544
|
-
cell._set_mixed_precision_type_recursive(mixed_type)
|
|
1545
|
-
|
|
1546
1764
|
def to_float(self, dst_type):
|
|
1547
1765
|
"""
|
|
1548
1766
|
Add cast on all inputs of cell and child cells to run with certain float type.
|
|
@@ -1555,13 +1773,13 @@ class Cell(Cell_):
|
|
|
1555
1773
|
|
|
1556
1774
|
Args:
|
|
1557
1775
|
dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
|
|
1558
|
-
dst_type can be `mstype.float16` or `mstype.
|
|
1776
|
+
dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.
|
|
1559
1777
|
|
|
1560
1778
|
Returns:
|
|
1561
1779
|
Cell, the cell itself.
|
|
1562
1780
|
|
|
1563
1781
|
Raises:
|
|
1564
|
-
ValueError: If dst_type is not mstype.float32
|
|
1782
|
+
ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.
|
|
1565
1783
|
|
|
1566
1784
|
Supported Platforms:
|
|
1567
1785
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1573,19 +1791,15 @@ class Cell(Cell_):
|
|
|
1573
1791
|
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
|
1574
1792
|
>>> net.to_float(mstype.float16)
|
|
1575
1793
|
Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
|
|
1576
|
-
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=
|
|
1577
|
-
"""
|
|
1578
|
-
if dst_type not in (mstype.float16, mstype.float32):
|
|
1579
|
-
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32
|
|
1580
|
-
"but got type: {} and value: {}.".format(type(dst_type), dst_type))
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
self.to_float_fp16 = True
|
|
1584
|
-
else:
|
|
1585
|
-
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
|
1586
|
-
self.to_float_fp16 = False
|
|
1587
|
-
flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
|
|
1794
|
+
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
|
|
1795
|
+
"""
|
|
1796
|
+
if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
|
|
1797
|
+
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
|
|
1798
|
+
"mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
|
|
1799
|
+
flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
|
|
1800
|
+
'bf16': dst_type == mstype.bfloat16}
|
|
1588
1801
|
self._add_init_args(**flags)
|
|
1802
|
+
self.add_flags_recursive(**flags)
|
|
1589
1803
|
return self
|
|
1590
1804
|
|
|
1591
1805
|
def set_boost(self, boost_type):
|
|
@@ -1594,7 +1808,7 @@ class Cell(Cell_):
|
|
|
1594
1808
|
accelerate the algorithm in the algorithm library.
|
|
1595
1809
|
|
|
1596
1810
|
If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
|
|
1597
|
-
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.
|
|
1811
|
+
`algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.
|
|
1598
1812
|
|
|
1599
1813
|
Note:
|
|
1600
1814
|
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
|
|
@@ -1651,12 +1865,12 @@ class Cell(Cell_):
|
|
|
1651
1865
|
|
|
1652
1866
|
Tutorial Examples:
|
|
1653
1867
|
- `Model Training - Implementing Training and Evaluation
|
|
1654
|
-
<https://mindspore.cn/tutorials/en/r2.
|
|
1868
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
|
|
1655
1869
|
"""
|
|
1656
|
-
if mode
|
|
1657
|
-
self._phase = 'predict'
|
|
1658
|
-
else:
|
|
1870
|
+
if mode:
|
|
1659
1871
|
self._phase = 'train'
|
|
1872
|
+
else:
|
|
1873
|
+
self._phase = 'predict'
|
|
1660
1874
|
self.add_flags_recursive(training=mode)
|
|
1661
1875
|
return self
|
|
1662
1876
|
|
|
@@ -1685,16 +1899,27 @@ class Cell(Cell_):
|
|
|
1685
1899
|
|
|
1686
1900
|
Args:
|
|
1687
1901
|
jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
|
|
1902
|
+
|
|
1903
|
+
Examples:
|
|
1904
|
+
>>> import mindspore as ms
|
|
1905
|
+
>>> from mindspore import Tensor, nn
|
|
1906
|
+
...
|
|
1907
|
+
>>> class Net(nn.Cell):
|
|
1908
|
+
... def __init__(self):
|
|
1909
|
+
... super(Net, self).__init__()
|
|
1910
|
+
... self.relu = nn.ReLU()
|
|
1911
|
+
...
|
|
1912
|
+
... def construct(self, x):
|
|
1913
|
+
... x = self.relu(x)
|
|
1914
|
+
... return x
|
|
1915
|
+
>>> net = Net()
|
|
1916
|
+
>>> jitconfig = ms.JitConfig()
|
|
1917
|
+
>>> net.set_jit_config(jitconfig)
|
|
1688
1918
|
"""
|
|
1689
1919
|
if self._jit_config_dict:
|
|
1690
1920
|
logger.warning("For Cell, jit config can only be set once, ignore this setting.")
|
|
1691
1921
|
else:
|
|
1692
1922
|
self._jit_config_dict = jit_config.jit_config_dict
|
|
1693
|
-
enable_ge = os.getenv("MS_ENABLE_GE") == '1'
|
|
1694
|
-
enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
|
|
1695
|
-
if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
|
|
1696
|
-
raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
|
|
1697
|
-
format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))
|
|
1698
1923
|
|
|
1699
1924
|
def flatten_weights(self, fusion_size=0):
|
|
1700
1925
|
"""
|
|
@@ -2290,12 +2515,13 @@ class Cell(Cell_):
|
|
|
2290
2515
|
def _run_tracefunc(self, *args, **kwargs):
|
|
2291
2516
|
""" Run Packed Cell in Pack."""
|
|
2292
2517
|
args = self._mixed_precision_cast(args)
|
|
2293
|
-
|
|
2518
|
+
need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
|
|
2519
|
+
if not PackFunc.current.is_pynative_mode and need_subgraph:
|
|
2294
2520
|
expander = PackExpander.get_instance()
|
|
2295
2521
|
args = expander.begin_subgraph(self, *args)
|
|
2296
2522
|
args = [_convert_tensor(a) for a in args]
|
|
2297
2523
|
output = self._run_construct(args, kwargs)
|
|
2298
|
-
ret = expander.end_subgraph(output)
|
|
2524
|
+
ret = expander.end_subgraph(self, output)
|
|
2299
2525
|
output = _convert_tensor(ret)
|
|
2300
2526
|
else:
|
|
2301
2527
|
with _SetMixedPrecision(self):
|
|
@@ -2306,10 +2532,23 @@ class Cell(Cell_):
|
|
|
2306
2532
|
mixed_type = self.get_mixed_precision_type()
|
|
2307
2533
|
if mixed_type == MixedPrecisionType.NOTSET:
|
|
2308
2534
|
return inputs
|
|
2309
|
-
|
|
2535
|
+
if mixed_type == MixedPrecisionType.FP16:
|
|
2536
|
+
cast_type = mstype.float16
|
|
2537
|
+
elif mixed_type == MixedPrecisionType.BF16:
|
|
2538
|
+
cast_type = mstype.bfloat16
|
|
2539
|
+
else:
|
|
2540
|
+
cast_type = mstype.float32
|
|
2310
2541
|
cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
|
|
2311
2542
|
return cast_inputs
|
|
2312
2543
|
|
|
2544
|
+
def _get_attr_from_cell(self, network):
|
|
2545
|
+
if not isinstance(network, Cell):
|
|
2546
|
+
return
|
|
2547
|
+
if hasattr(network, "jit_config_dict"):
|
|
2548
|
+
self._jit_config_dict = network.jit_config_dict
|
|
2549
|
+
if hasattr(network, "_amp_level"):
|
|
2550
|
+
self._amp_level = getattr(network, "_amp_level")
|
|
2551
|
+
|
|
2313
2552
|
|
|
2314
2553
|
class GraphCell(Cell):
|
|
2315
2554
|
"""
|