mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/rewrite/api/node.py
CHANGED
|
@@ -14,12 +14,13 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: Node."""
|
|
16
16
|
|
|
17
|
-
from typing import Union, Optional
|
|
17
|
+
from typing import Union, Optional, List, Dict
|
|
18
|
+
from types import FunctionType
|
|
18
19
|
|
|
19
20
|
from mindspore.nn import Cell
|
|
20
21
|
from mindspore.ops.primitive import Primitive
|
|
21
22
|
from mindspore import _checkparam as Validator
|
|
22
|
-
from ..node import Node as NodeImpl
|
|
23
|
+
from ..node.node import Node as NodeImpl
|
|
23
24
|
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
|
24
25
|
from .node_type import NodeType
|
|
25
26
|
from .scoped_value import ScopedValue
|
|
@@ -50,8 +51,8 @@ class Node:
|
|
|
50
51
|
return self._node == other._node
|
|
51
52
|
|
|
52
53
|
@staticmethod
|
|
53
|
-
def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
|
54
|
-
kwargs:
|
|
54
|
+
def create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None,
|
|
55
|
+
kwargs: Dict[str, ScopedValue] = None, name: str = "", is_sub_net: bool = False) -> 'Node':
|
|
55
56
|
"""
|
|
56
57
|
Create a node. Only support create from a `Cell` now.
|
|
57
58
|
|
|
@@ -63,14 +64,15 @@ class Node:
|
|
|
63
64
|
|
|
64
65
|
Args:
|
|
65
66
|
cell (Cell): Cell-operator of this forward-layer.
|
|
66
|
-
targets (
|
|
67
|
-
|
|
67
|
+
targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
|
|
68
|
+
source code.
|
|
69
|
+
args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
68
70
|
source code. Default: ``None`` , which indicates the `cell` has no args inputs.
|
|
69
|
-
kwargs (
|
|
71
|
+
kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
70
72
|
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
71
73
|
code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
|
|
72
74
|
name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
|
|
73
|
-
generate name from `
|
|
75
|
+
generate name from `cell` when name is None. Rewrite will check and ensure the uniqueness of `name`
|
|
74
76
|
while node being inserted. Default: ``""`` .
|
|
75
77
|
is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
|
|
76
78
|
the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node. Default: ``False`` .
|
|
@@ -89,7 +91,7 @@ class Node:
|
|
|
89
91
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
90
92
|
>>> import mindspore.nn as nn
|
|
91
93
|
>>> # Define the network structure of LeNet5. Refer to
|
|
92
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
94
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
93
95
|
>>> net = LeNet5()
|
|
94
96
|
>>> stree = SymbolTree.create(net)
|
|
95
97
|
>>> node = stree.get_node("conv1")
|
|
@@ -108,8 +110,66 @@ class Node:
|
|
|
108
110
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
109
111
|
if kwargs is not None:
|
|
110
112
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
111
|
-
return Node(NodeImpl.create_call_op(cell, None, targets,
|
|
112
|
-
|
|
113
|
+
return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))
|
|
114
|
+
|
|
115
|
+
@staticmethod
|
|
116
|
+
def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
|
|
117
|
+
args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
|
|
118
|
+
"""
|
|
119
|
+
Create a node that corresponds to a function call. The `function` object is saved into network, and used via
|
|
120
|
+
getting object from `self.` .
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
function (FunctionType): The function to be called.
|
|
124
|
+
targets (List[Union[ScopedValue, str]]): indicates output names. Used as targets of an assign statement in
|
|
125
|
+
source code.
|
|
126
|
+
args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
127
|
+
source code. Default: ``None`` , which indicates the `function` has no args inputs.
|
|
128
|
+
kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
129
|
+
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
130
|
+
code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
|
|
131
|
+
|
|
132
|
+
Returns:
|
|
133
|
+
An instance of `Node`.
|
|
134
|
+
|
|
135
|
+
Raises:
|
|
136
|
+
TypeError: If `function` is not a `FunctionType`.
|
|
137
|
+
TypeError: If `targets` is not `list`.
|
|
138
|
+
TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
|
|
139
|
+
TypeError: If arg in `args` is not a `ScopedValue`.
|
|
140
|
+
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
|
|
141
|
+
|
|
142
|
+
Examples:
|
|
143
|
+
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
144
|
+
>>> import mindspore.nn as nn
|
|
145
|
+
>>> import mindspore.ops as ops
|
|
146
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
147
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
148
|
+
>>> net = LeNet5()
|
|
149
|
+
>>> stree = SymbolTree.create(net)
|
|
150
|
+
>>> node = stree.get_node("conv1")
|
|
151
|
+
>>> position = stree.after(node)
|
|
152
|
+
>>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
|
|
153
|
+
... args=[ScopedValue.create_naming_value('x')])
|
|
154
|
+
>>> stree.insert(position, new_node)
|
|
155
|
+
>>> print(new_node.get_node_type())
|
|
156
|
+
NodeType.CallFunction
|
|
157
|
+
"""
|
|
158
|
+
Validator.check_value_type("function", function, [FunctionType, type], "create_call_function")
|
|
159
|
+
Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
|
|
160
|
+
if args is not None:
|
|
161
|
+
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
|
|
162
|
+
if kwargs is not None:
|
|
163
|
+
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
|
|
164
|
+
return Node(NodeImpl._create_call_function(function, targets, args, kwargs))
|
|
165
|
+
|
|
166
|
+
@staticmethod
|
|
167
|
+
def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
|
|
168
|
+
# pylint: disable=missing-function-docstring
|
|
169
|
+
Validator.check_value_type("param_name", param_name, [str], "Node")
|
|
170
|
+
if default is not None:
|
|
171
|
+
Validator.check_value_type("default", default, [ScopedValue], "Node")
|
|
172
|
+
return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))
|
|
113
173
|
|
|
114
174
|
def get_handler(self) -> NodeImpl:
|
|
115
175
|
return self._node
|
|
@@ -124,7 +184,7 @@ class Node:
|
|
|
124
184
|
Examples:
|
|
125
185
|
>>> from mindspore.rewrite import SymbolTree
|
|
126
186
|
>>> # Define the network structure of LeNet5. Refer to
|
|
127
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
187
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
128
188
|
>>> net = LeNet5()
|
|
129
189
|
>>> stree = SymbolTree.create(net)
|
|
130
190
|
>>> node = stree.get_node("conv2")
|
|
@@ -144,7 +204,7 @@ class Node:
|
|
|
144
204
|
Examples:
|
|
145
205
|
>>> from mindspore.rewrite import SymbolTree
|
|
146
206
|
>>> # Define the network structure of LeNet5. Refer to
|
|
147
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
207
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
148
208
|
>>> net = LeNet5()
|
|
149
209
|
>>> stree = SymbolTree.create(net)
|
|
150
210
|
>>> node = stree.get_node("conv1")
|
|
@@ -177,7 +237,7 @@ class Node:
|
|
|
177
237
|
Examples:
|
|
178
238
|
>>> from mindspore.rewrite import SymbolTree
|
|
179
239
|
>>> # Define the network structure of LeNet5. Refer to
|
|
180
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
240
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
181
241
|
>>> net = LeNet5()
|
|
182
242
|
>>> stree = SymbolTree.create(net)
|
|
183
243
|
>>> node = stree.get_node("relu_3")
|
|
@@ -216,7 +276,7 @@ class Node:
|
|
|
216
276
|
Examples:
|
|
217
277
|
>>> from mindspore.rewrite import SymbolTree
|
|
218
278
|
>>> # Define the network structure of LeNet5. Refer to
|
|
219
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
279
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
220
280
|
>>> net = LeNet5()
|
|
221
281
|
>>> stree = SymbolTree.create(net)
|
|
222
282
|
>>> src_node = stree.get_node("fc1")
|
|
@@ -256,7 +316,7 @@ class Node:
|
|
|
256
316
|
Examples:
|
|
257
317
|
>>> from mindspore.rewrite import SymbolTree
|
|
258
318
|
>>> # Define the network structure of LeNet5. Refer to
|
|
259
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
319
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
260
320
|
>>> net = LeNet5()
|
|
261
321
|
>>> stree = SymbolTree.create(net)
|
|
262
322
|
>>> node = stree.get_node("conv1")
|
|
@@ -276,7 +336,7 @@ class Node:
|
|
|
276
336
|
Examples:
|
|
277
337
|
>>> from mindspore.rewrite import SymbolTree
|
|
278
338
|
>>> # Define the network structure of LeNet5. Refer to
|
|
279
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
339
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
280
340
|
>>> net = LeNet5()
|
|
281
341
|
>>> stree = SymbolTree.create(net)
|
|
282
342
|
>>> node = stree.get_node("conv1")
|
|
@@ -303,7 +363,7 @@ class Node:
|
|
|
303
363
|
Examples:
|
|
304
364
|
>>> from mindspore.rewrite import SymbolTree
|
|
305
365
|
>>> # Define the network structure of LeNet5. Refer to
|
|
306
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
366
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
307
367
|
>>> net = LeNet5()
|
|
308
368
|
>>> stree = SymbolTree.create(net)
|
|
309
369
|
>>> node = stree.get_node("conv1")
|
|
@@ -326,7 +386,7 @@ class Node:
|
|
|
326
386
|
Examples:
|
|
327
387
|
>>> from mindspore.rewrite import SymbolTree
|
|
328
388
|
>>> # Define the network structure of LeNet5. Refer to
|
|
329
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
389
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
330
390
|
>>> net = LeNet5()
|
|
331
391
|
>>> stree = SymbolTree.create(net)
|
|
332
392
|
>>> node = stree.get_node("conv1")
|
|
@@ -335,6 +395,29 @@ class Node:
|
|
|
335
395
|
"""
|
|
336
396
|
return self._node.get_args()
|
|
337
397
|
|
|
398
|
+
def get_symbol_tree(self) -> 'SymbolTree':
|
|
399
|
+
"""
|
|
400
|
+
Get the symbol tree which current node belongs to.
|
|
401
|
+
|
|
402
|
+
Returns:
|
|
403
|
+
SymbolTree, None if current node does not belong to any SymbolTree.
|
|
404
|
+
|
|
405
|
+
Examples:
|
|
406
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
407
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
408
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
409
|
+
>>> net = LeNet5()
|
|
410
|
+
>>> stree = SymbolTree.create(net)
|
|
411
|
+
>>> node = stree.get_node("conv1")
|
|
412
|
+
>>> print(type(node.get_symbol_tree()))
|
|
413
|
+
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
|
|
414
|
+
"""
|
|
415
|
+
from .symbol_tree import SymbolTree
|
|
416
|
+
stree_impl = self._node.get_belong_symbol_tree()
|
|
417
|
+
if not stree_impl:
|
|
418
|
+
return None
|
|
419
|
+
return SymbolTree(stree_impl)
|
|
420
|
+
|
|
338
421
|
def get_kwargs(self) -> {str: ScopedValue}:
|
|
339
422
|
return self._node.get_kwargs()
|
|
340
423
|
|
|
@@ -23,14 +23,17 @@ class NodeType(Enum):
|
|
|
23
23
|
- Unknown: Not inited NodeType.
|
|
24
24
|
- CallCell: `CallCell` node represents invoking cell-op in forward method.
|
|
25
25
|
- CallPrimitive: `CallPrimitive` node represents invoking primitive-op in forward method.
|
|
26
|
-
- CallFunction: `CallFunction` node represents invoking
|
|
26
|
+
- CallFunction: `CallFunction` node represents invoking a function in forward method.
|
|
27
27
|
- CallMethod: `CallMethod` node represents invoking of method in forward method which can not be mapped to
|
|
28
28
|
cell-op or primitive-op in MindSpore.
|
|
29
29
|
- Python: `Python` node holds unsupported-ast-node or unnecessary-to-parse-ast-node.
|
|
30
30
|
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
|
|
31
31
|
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
|
|
32
32
|
- Tree: `Tree` node represents sub-network invoking in forward method.
|
|
33
|
+
- CellContainer: `CellContainer` node represents invoking method :class:`mindspore.nn.SequentialCell` in
|
|
34
|
+
forward method.
|
|
33
35
|
- MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
|
|
36
|
+
- ControlFlow: `ControlFlow` node represents a control flow statement, such as if statement.
|
|
34
37
|
|
|
35
38
|
"""
|
|
36
39
|
Unknown = 0
|
|
@@ -47,3 +50,4 @@ class NodeType(Enum):
|
|
|
47
50
|
Tree = 9
|
|
48
51
|
CellContainer = 10
|
|
49
52
|
MathOps = 11
|
|
53
|
+
ControlFlow = 12
|
|
@@ -364,7 +364,7 @@ class PatternEngine:
|
|
|
364
364
|
continue
|
|
365
365
|
if cur_node.get_node_type() == NodeType.Tree:
|
|
366
366
|
subtree = TreeNodeHelper.get_sub_tree(cur_node)
|
|
367
|
-
self.apply(subtree)
|
|
367
|
+
_ = self.apply(subtree)
|
|
368
368
|
visited.append(cur_node)
|
|
369
369
|
queue.extend(cur_node.get_users())
|
|
370
370
|
continue
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: ValueType and ScopedValue."""
|
|
16
16
|
from enum import Enum
|
|
17
|
-
from typing import Optional, Union
|
|
17
|
+
from typing import Optional, Union, List, Tuple
|
|
18
18
|
from mindspore import _checkparam as Validator
|
|
19
19
|
|
|
20
20
|
|
|
@@ -28,10 +28,7 @@ class ValueType(Enum):
|
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
30
|
# base type
|
|
31
|
-
|
|
32
|
-
IntValue = 1
|
|
33
|
-
FloatValue = 2
|
|
34
|
-
NoneValue = 3
|
|
31
|
+
ConstantValue = 0
|
|
35
32
|
# container type
|
|
36
33
|
TupleValue = 20
|
|
37
34
|
ListValue = 21
|
|
@@ -82,14 +79,8 @@ class ScopedValue:
|
|
|
82
79
|
>>> print(variable)
|
|
83
80
|
2
|
|
84
81
|
"""
|
|
85
|
-
if value
|
|
86
|
-
return cls(ValueType.
|
|
87
|
-
if isinstance(value, int):
|
|
88
|
-
return cls(ValueType.IntValue, "", value)
|
|
89
|
-
if isinstance(value, float):
|
|
90
|
-
return cls(ValueType.FloatValue, "", value)
|
|
91
|
-
if isinstance(value, str):
|
|
92
|
-
return cls(ValueType.StringValue, "", value)
|
|
82
|
+
if isinstance(value, (type(None), int, float, str, bool)):
|
|
83
|
+
return cls(ValueType.ConstantValue, "", value)
|
|
93
84
|
if isinstance(value, tuple):
|
|
94
85
|
return cls(ValueType.TupleValue, "",
|
|
95
86
|
tuple(cls.create_variable_value(single_value) for single_value in value))
|
|
@@ -130,13 +121,14 @@ class ScopedValue:
|
|
|
130
121
|
return cls(ValueType.NamingValue, scope, name)
|
|
131
122
|
|
|
132
123
|
@staticmethod
|
|
133
|
-
def create_name_values(names: Union[
|
|
124
|
+
def create_name_values(names: Union[List[str], Tuple[str]],
|
|
125
|
+
scopes: Union[List[str], Tuple[str]] = None) -> List['ScopedValue']:
|
|
134
126
|
"""
|
|
135
127
|
Create a list of naming `ScopedValue`.
|
|
136
128
|
|
|
137
129
|
Args:
|
|
138
|
-
names (
|
|
139
|
-
scopes (
|
|
130
|
+
names (List[str] or Tuple[str]): List or tuple of `str` represents names of referenced variables.
|
|
131
|
+
scopes (List[str] or Tuple[str]): List or tuple of `str` represents scopes of referenced variables.
|
|
140
132
|
Default: ``None`` .
|
|
141
133
|
|
|
142
134
|
Returns:
|
|
@@ -168,7 +160,7 @@ class ScopedValue:
|
|
|
168
160
|
return result
|
|
169
161
|
|
|
170
162
|
def __str__(self):
|
|
171
|
-
if self.type
|
|
163
|
+
if self.type == ValueType.ConstantValue:
|
|
172
164
|
return str(self.value)
|
|
173
165
|
if self.type == ValueType.NamingValue:
|
|
174
166
|
return f"{self.scope}.{self.value}" if self.scope else str(self.value)
|