mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -15,24 +15,25 @@
|
|
|
15
15
|
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
|
16
16
|
import ast
|
|
17
17
|
from mindspore import log as logger
|
|
18
|
-
from
|
|
19
|
-
from
|
|
18
|
+
from .parser_register import ParserRegister, reg_parser
|
|
19
|
+
from .parser import Parser
|
|
20
20
|
from ..symbol_tree import SymbolTree
|
|
21
21
|
from ..api.node_type import NodeType
|
|
22
|
+
from ..node.node_manager import NodeManager
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class FunctionDefParser(Parser):
|
|
25
|
-
"""Parse bodies of ast.FunctionDef
|
|
26
|
+
"""Parse bodies of ast.FunctionDef in SymbolTree."""
|
|
26
27
|
|
|
27
28
|
def target(self):
|
|
28
29
|
"""Parse target type"""
|
|
29
30
|
return ast.FunctionDef
|
|
30
31
|
|
|
31
|
-
def remove_dead_code(self,
|
|
32
|
+
def remove_dead_code(self, node_manager: NodeManager):
|
|
32
33
|
"""Remove dead codes"""
|
|
33
34
|
# Find out return node position
|
|
34
35
|
return_idx = -1
|
|
35
|
-
for idx, node in enumerate(
|
|
36
|
+
for idx, node in enumerate(node_manager.nodes()):
|
|
36
37
|
if node.get_node_type() == NodeType.Output:
|
|
37
38
|
return_idx = idx
|
|
38
39
|
break
|
|
@@ -40,29 +41,36 @@ class FunctionDefParser(Parser):
|
|
|
40
41
|
return
|
|
41
42
|
# Remove nodes after return node.
|
|
42
43
|
# Reverse traversal to ensure that nodes are orphaned and can be deleted.
|
|
43
|
-
for idx, node in reversed(list(enumerate(
|
|
44
|
+
for idx, node in reversed(list(enumerate(node_manager.nodes()))):
|
|
44
45
|
if idx <= return_idx:
|
|
45
46
|
break
|
|
46
47
|
logger.info(f"Remove dead code node:{node.get_name()}")
|
|
47
|
-
|
|
48
|
+
node_manager.erase_node(node)
|
|
48
49
|
|
|
49
|
-
def process(self, stree: SymbolTree,
|
|
50
|
-
"""
|
|
51
|
-
|
|
50
|
+
def process(self, stree: SymbolTree, ast_node: ast.FunctionDef, node_manager: NodeManager):
|
|
51
|
+
"""
|
|
52
|
+
Parse bodies of ast.FunctionDef in SymbolTree.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
stree (SymbolTree): symbol tree under parsing.
|
|
56
|
+
ast_node (ast.FunctionDef): Ast FunctionDef node in construct.
|
|
57
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
58
|
+
"""
|
|
52
59
|
# parse args as inputs of stree
|
|
53
|
-
arguments: ast.arguments =
|
|
60
|
+
arguments: ast.arguments = ast_node.args
|
|
54
61
|
parser: Parser = ParserRegister.instance().get_parser(ast.arguments)
|
|
55
|
-
parser.process(stree, arguments)
|
|
62
|
+
parser.process(stree, arguments, node_manager)
|
|
56
63
|
|
|
57
64
|
# parse body as node of stree
|
|
58
|
-
for body in
|
|
65
|
+
for body in ast_node.body:
|
|
59
66
|
# avoid add dead code, so we need to break if return is added.
|
|
60
67
|
parser: Parser = ParserRegister.instance().get_parser(type(body))
|
|
61
68
|
if parser is None:
|
|
62
|
-
stree.append_python_node(
|
|
69
|
+
stree.append_python_node(ast_node, body, node_manager)
|
|
63
70
|
else:
|
|
64
|
-
parser.process(stree, body)
|
|
65
|
-
|
|
71
|
+
parser.process(stree, body, node_manager)
|
|
72
|
+
|
|
73
|
+
self.remove_dead_code(node_manager)
|
|
66
74
|
|
|
67
75
|
|
|
68
76
|
g_functiondef_parser = reg_parser(FunctionDefParser())
|
|
@@ -15,11 +15,12 @@
|
|
|
15
15
|
"""Parse ast.If in construct function to node of SymbolTree."""
|
|
16
16
|
|
|
17
17
|
import ast
|
|
18
|
-
import astunparse
|
|
19
18
|
|
|
20
19
|
from ..symbol_tree import SymbolTree
|
|
21
|
-
from
|
|
22
|
-
from
|
|
20
|
+
from .parser import Parser
|
|
21
|
+
from .parser_register import ParserRegister, reg_parser
|
|
22
|
+
from ..node import NodeManager, ControlFlow
|
|
23
|
+
from ..ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class IfParser(Parser):
|
|
@@ -29,35 +30,38 @@ class IfParser(Parser):
|
|
|
29
30
|
"""Parse target type"""
|
|
30
31
|
return ast.If
|
|
31
32
|
|
|
32
|
-
def process(self, stree: SymbolTree, node: ast.If):
|
|
33
|
+
def process(self, stree: SymbolTree, node: ast.If, node_manager: NodeManager):
|
|
33
34
|
"""
|
|
34
|
-
Parse ast.If and create
|
|
35
|
+
Parse ast.If and create nodes into symbol tree.
|
|
35
36
|
|
|
36
37
|
Args:
|
|
37
38
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
38
39
|
node ([ast.If]): An ast.If node.
|
|
40
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
39
41
|
|
|
40
42
|
Raises:
|
|
41
43
|
NotImplementedError: If test of ast.If can not be eval.
|
|
42
44
|
"""
|
|
45
|
+
# expand codes in ast.if
|
|
46
|
+
ast_if = FlattenRecursiveStmt().transform_if(node, stree)
|
|
47
|
+
# parse ast codes of if branch into ControlFlow Node
|
|
48
|
+
if_node = ControlFlow("if_node", ast_if.body, stree)
|
|
49
|
+
for body in ast_if.body:
|
|
50
|
+
parser: Parser = ParserRegister.instance().get_parser(type(body))
|
|
51
|
+
if parser is None:
|
|
52
|
+
stree.append_python_node(ast_if, body, node_manager=if_node)
|
|
53
|
+
else:
|
|
54
|
+
parser.process(stree, body, node_manager=if_node)
|
|
55
|
+
stree.append_origin_field(if_node, node_manager)
|
|
56
|
+
# parse ast codes of else branch into ControlFlow Node
|
|
57
|
+
if ast_if.orelse:
|
|
58
|
+
else_node = ControlFlow("else_node", ast_if.orelse, stree)
|
|
59
|
+
for body in ast_if.orelse:
|
|
60
|
+
parser: Parser = ParserRegister.instance().get_parser(type(body))
|
|
61
|
+
if parser is None:
|
|
62
|
+
stree.append_python_node(ast_if, body, node_manager=else_node)
|
|
63
|
+
else:
|
|
64
|
+
parser.process(stree, body, node_manager=else_node)
|
|
65
|
+
stree.append_origin_field(else_node, node_manager)
|
|
43
66
|
|
|
44
|
-
test_code = astunparse.unparse(node.test)
|
|
45
|
-
test_code = test_code.replace("self", "stree.get_origin_network()")
|
|
46
|
-
bodies = None
|
|
47
|
-
try:
|
|
48
|
-
test_value = eval(test_code)
|
|
49
|
-
except (NameError, TypeError):
|
|
50
|
-
stree.try_append_python_node(node, node)
|
|
51
|
-
return
|
|
52
|
-
|
|
53
|
-
bodies = node.body if test_value else node.orelse
|
|
54
|
-
index = stree.get_ast_root().body.index(node) + 1
|
|
55
|
-
info_node = ast.Name(id=f"# If node has been replaced by {bool(test_value)} branch.",
|
|
56
|
-
lineno=0, col_offset=0, ctx=ast.Load)
|
|
57
|
-
exp_node = ast.Expr(value=info_node, lineno=0, col_offset=0, ctx=ast.Load)
|
|
58
|
-
stree.get_ast_root().body.insert(index-1, exp_node)
|
|
59
|
-
for body in bodies:
|
|
60
|
-
stree.get_ast_root().body.insert(index, body)
|
|
61
|
-
index += 1
|
|
62
|
-
stree.get_ast_root().body.remove(node)
|
|
63
67
|
g_if_parser = reg_parser(IfParser())
|
|
@@ -13,23 +13,32 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.Module to SymbolTrees."""
|
|
16
|
+
import sys
|
|
16
17
|
from typing import Any
|
|
17
18
|
import os
|
|
18
19
|
import ast
|
|
19
20
|
import copy
|
|
20
21
|
import inspect
|
|
21
|
-
import astunparse
|
|
22
22
|
|
|
23
23
|
from mindspore import log as logger
|
|
24
24
|
from ..symbol_tree import SymbolTree
|
|
25
|
-
from
|
|
26
|
-
from
|
|
25
|
+
from .parser import Parser
|
|
26
|
+
from .parser_register import ParserRegister, reg_parser
|
|
27
27
|
from ..ast_helpers import AstFinder
|
|
28
28
|
from ..common import error_str
|
|
29
|
+
from ..node.node_manager import NodeManager
|
|
29
30
|
|
|
31
|
+
if sys.version_info >= (3, 9):
|
|
32
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
33
|
+
else:
|
|
34
|
+
import astunparse
|
|
30
35
|
|
|
31
36
|
class ModuleParser(Parser):
|
|
32
37
|
"""Parse ast.Module to SymbolTrees."""
|
|
38
|
+
|
|
39
|
+
# a denied_class_decorator_list represents the decorators should be banned, which is registered by user
|
|
40
|
+
denied_class_decorator_list = []
|
|
41
|
+
|
|
33
42
|
@staticmethod
|
|
34
43
|
def _find_class(ast_node: ast.Module) -> ast.ClassDef:
|
|
35
44
|
"""Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now."""
|
|
@@ -45,18 +54,27 @@ class ModuleParser(Parser):
|
|
|
45
54
|
def _get_import_node(ast_root):
|
|
46
55
|
"""Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root."""
|
|
47
56
|
import_nodes = []
|
|
57
|
+
try_nodes = []
|
|
58
|
+
imports_str = []
|
|
48
59
|
|
|
49
60
|
class GetImportNode(ast.NodeVisitor):
|
|
50
61
|
"""Find all import nodes from input ast node."""
|
|
51
62
|
|
|
63
|
+
def visit_Try(self, node: ast.Try) -> Any:
|
|
64
|
+
if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
|
|
65
|
+
try_nodes.append(copy.deepcopy(node))
|
|
66
|
+
return node
|
|
67
|
+
|
|
52
68
|
def visit_Import(self, node: ast.Import) -> Any:
|
|
53
69
|
"""Iterate over all nodes and save ast.Import nodes."""
|
|
54
70
|
import_nodes.append(copy.deepcopy(node))
|
|
71
|
+
imports_str.append(astunparse.unparse(node))
|
|
55
72
|
return node
|
|
56
73
|
|
|
57
74
|
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
58
75
|
"""Iterate over all nodes and save ast.ImportFrom nodes."""
|
|
59
76
|
import_nodes.append(copy.deepcopy(node))
|
|
77
|
+
imports_str.append(astunparse.unparse(node))
|
|
60
78
|
return node
|
|
61
79
|
|
|
62
80
|
def get_node(self, input_ast):
|
|
@@ -64,19 +82,145 @@ class ModuleParser(Parser):
|
|
|
64
82
|
self.generic_visit(input_ast)
|
|
65
83
|
return True
|
|
66
84
|
|
|
85
|
+
def _remove_duplicated_import_in_try(node: [ast.Import, ast.ImportFrom]):
|
|
86
|
+
import_str = astunparse.unparse(node)
|
|
87
|
+
if import_str in imports_str:
|
|
88
|
+
import_nodes.remove(import_nodes[imports_str.index(import_str)])
|
|
89
|
+
|
|
67
90
|
get_node_handler = GetImportNode()
|
|
68
91
|
get_node_handler.get_node(ast_root)
|
|
92
|
+
for Try in try_nodes:
|
|
93
|
+
for body in Try.body:
|
|
94
|
+
_remove_duplicated_import_in_try(body)
|
|
95
|
+
for handler in Try.handlers:
|
|
96
|
+
for body in handler.body:
|
|
97
|
+
_remove_duplicated_import_in_try(body)
|
|
98
|
+
import_nodes.extend(try_nodes)
|
|
69
99
|
return import_nodes
|
|
70
100
|
|
|
71
101
|
@staticmethod
|
|
72
|
-
def
|
|
102
|
+
def save_file_path_to_sys(stree, level_num, file_path):
|
|
103
|
+
"""
|
|
104
|
+
Save file path into stree._import_asts. `level_num` is used when level exist in ast.ImportFrom.
|
|
105
|
+
|
|
106
|
+
When level_num = 0(e.g. from xxx import yyy), current path will be saved.
|
|
107
|
+
When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
|
|
108
|
+
When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.
|
|
109
|
+
"""
|
|
110
|
+
file_path = os.path.dirname(os.path.abspath(file_path))
|
|
111
|
+
if level_num > 1:
|
|
112
|
+
for _ in range(level_num - 1):
|
|
113
|
+
file_path = os.path.dirname(file_path)
|
|
114
|
+
sys_path_append_ast = ast.parse(f"sys.path.insert(0, r'{file_path}')").body[0]
|
|
115
|
+
stree.get_import_asts().append(ast.Import([ast.alias(name='sys', asname=None)]))
|
|
116
|
+
stree.get_import_asts().append(sys_path_append_ast)
|
|
117
|
+
|
|
118
|
+
@staticmethod
|
|
119
|
+
def _save_imports(stree):
|
|
73
120
|
"""Insert two groups of import nodes to ast.Module, common ones and those from class definition file."""
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
121
|
+
stree.get_import_asts().append(ast.Import([ast.alias(name='mindspore', asname=None)]))
|
|
122
|
+
stree.get_import_asts().append(ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)],
|
|
123
|
+
level=0))
|
|
124
|
+
stree.get_import_asts().append(ast.ImportFrom(module='mindspore.nn',
|
|
125
|
+
names=[ast.alias(name='Cell', asname=None)], level=0))
|
|
126
|
+
stree.get_import_asts().append(ast.ImportFrom(module='mindspore.ops',
|
|
127
|
+
names=[ast.alias(name='functional', asname='F')], level=0))
|
|
128
|
+
origin_net = stree.get_origin_network()
|
|
129
|
+
net_path = inspect.getfile(type(origin_net))
|
|
130
|
+
ModuleParser.save_file_path_to_sys(stree, 0, net_path)
|
|
131
|
+
ModuleParser.save_imports_from_file(stree, net_path)
|
|
132
|
+
|
|
133
|
+
@staticmethod
|
|
134
|
+
def get_valid_import_info(import_node, file_path):
|
|
135
|
+
"""Get valid import info while import_node.module is at form of relative path"""
|
|
136
|
+
# copy to a new node to avoid origin import_node being modified.
|
|
137
|
+
import_node_test = copy.deepcopy(import_node)
|
|
138
|
+
file_path = os.path.dirname(os.path.abspath(file_path))
|
|
139
|
+
# get real path from import_node.level
|
|
140
|
+
# from .(A) import xxx: current path
|
|
141
|
+
# from ..(A) import xxx: last level path
|
|
142
|
+
import_node_module_name = import_node.module
|
|
143
|
+
level = import_node.level
|
|
144
|
+
# from A import xxx: it does not need to pad, directly return the module name
|
|
145
|
+
if level == 0:
|
|
146
|
+
return import_node_module_name, None
|
|
147
|
+
if level > 1:
|
|
148
|
+
for _ in range(level - 1):
|
|
149
|
+
file_path = os.path.dirname(file_path)
|
|
150
|
+
file_path_tmp = file_path[:]
|
|
151
|
+
max_level_count = file_path.count('/') + file_path.count('\\') - 1
|
|
152
|
+
level_count = 0
|
|
153
|
+
# suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx'
|
|
154
|
+
suffix = ''
|
|
155
|
+
if import_node_module_name:
|
|
156
|
+
suffix = '.' + import_node_module_name
|
|
157
|
+
while level_count < max_level_count:
|
|
158
|
+
file_path_tmp = os.path.dirname(file_path_tmp)
|
|
159
|
+
import_node_test.module = file_path[len(file_path_tmp) + 1:].replace('/', '.') + suffix
|
|
160
|
+
import_node_test.level = 0
|
|
161
|
+
import_code = astunparse.unparse(import_node_test).strip()
|
|
162
|
+
test_code = f"import sys\nsys.path.insert(0, r'{file_path_tmp}')\n{import_code}"
|
|
163
|
+
try:
|
|
164
|
+
exec(test_code) # pylint: disable=W0122
|
|
165
|
+
except (ValueError, ImportError) as e:
|
|
166
|
+
# try upper level to avoid ValueError: attempted relative import beyond top-level package
|
|
167
|
+
# this exception is changed to ImportError after python3.9
|
|
168
|
+
logger.info(f"For MindSpore Rewrite, in module parser, test import code: "
|
|
169
|
+
f"{import_code} failed: {e}. Try upper level.")
|
|
170
|
+
level_count += 1
|
|
171
|
+
continue
|
|
172
|
+
except Exception as e: # pylint: disable=W0703
|
|
173
|
+
logger.warning(f"For MindSpore Rewrite, in module parser, process import code: "
|
|
174
|
+
f"{import_code} failed: {e}. Ignore this import code.")
|
|
175
|
+
return None, None
|
|
176
|
+
else:
|
|
177
|
+
# try test code success
|
|
178
|
+
return import_node_test.module, file_path_tmp
|
|
179
|
+
# try codes with all level failed
|
|
180
|
+
logger.warning(f"For MindSpore Rewrite, in module parser, test import code: "
|
|
181
|
+
f"{astunparse.unparse(import_node).strip()} failed. Ignore this import code.")
|
|
182
|
+
return None, None
|
|
183
|
+
|
|
184
|
+
@staticmethod
|
|
185
|
+
def save_imports_from_file(stree, file_path):
|
|
186
|
+
"""Save imports from file"""
|
|
187
|
+
if not os.path.exists(file_path):
|
|
188
|
+
raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
|
|
189
|
+
try:
|
|
190
|
+
with open(file_path, "r", encoding="utf-8") as f:
|
|
191
|
+
source_code = f.read()
|
|
192
|
+
import_nodes = ModuleParser._get_import_node(ast.parse(source_code))
|
|
193
|
+
except RuntimeError as err:
|
|
194
|
+
raise RuntimeError(f"For MindSpore Rewrite, in module parser, get import nodes error: {err}")
|
|
195
|
+
if not import_nodes:
|
|
196
|
+
return
|
|
197
|
+
for import_node in import_nodes:
|
|
198
|
+
import_node = ModuleParser._process_relative_import(stree, import_node, file_path)
|
|
199
|
+
if import_node:
|
|
200
|
+
stree.get_import_asts().append(import_node)
|
|
201
|
+
|
|
202
|
+
@staticmethod
|
|
203
|
+
def _process_relative_import(stree, import_node, file_path):
|
|
204
|
+
"""Process relative imports"""
|
|
205
|
+
if isinstance(import_node, ast.ImportFrom):
|
|
206
|
+
# pad the ImportFrom with parent path
|
|
207
|
+
# e.g. from ..C import xxx -> from A.B.C import xxx
|
|
208
|
+
import_module, import_path = ModuleParser.get_valid_import_info(import_node, file_path)
|
|
209
|
+
if import_path:
|
|
210
|
+
ModuleParser.save_file_path_to_sys(stree, 0, import_path)
|
|
211
|
+
module_name_list = [alias.name.strip() for alias in import_node.names]
|
|
212
|
+
# add the module into _imported_modules to direct the class
|
|
213
|
+
stree.save_imported_modules(file_path, import_module, module_name_list)
|
|
214
|
+
import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0)
|
|
215
|
+
elif isinstance(import_node, ast.Import):
|
|
216
|
+
for alias in import_node.names:
|
|
217
|
+
name = alias.name
|
|
218
|
+
stree.save_imported_modules(file_path, name.strip(), [])
|
|
219
|
+
return import_node
|
|
220
|
+
|
|
221
|
+
@staticmethod
|
|
222
|
+
def _add_decorator_to_class(class_ast: ast.ClassDef, origin_net):
|
|
223
|
+
"""Add decorators to class"""
|
|
80
224
|
origin_net_source_code_file = inspect.getfile(type(origin_net))
|
|
81
225
|
if not os.path.exists(origin_net_source_code_file):
|
|
82
226
|
raise RuntimeError("For MindSpore Rewrite, in module parser, File ", origin_net_source_code_file,
|
|
@@ -84,35 +228,62 @@ class ModuleParser(Parser):
|
|
|
84
228
|
try:
|
|
85
229
|
with open(origin_net_source_code_file, "r", encoding="utf-8") as f:
|
|
86
230
|
source_code = f.read()
|
|
87
|
-
|
|
231
|
+
decorators = ModuleParser._get_decorator(ast.parse(source_code), origin_net)
|
|
88
232
|
except RuntimeError:
|
|
89
|
-
raise RuntimeError("For MindSpore Rewrite, in module parser, get
|
|
90
|
-
if
|
|
91
|
-
for
|
|
92
|
-
|
|
93
|
-
ast.fix_missing_locations(
|
|
233
|
+
raise RuntimeError("For MindSpore Rewrite, in module parser, get decorators error")
|
|
234
|
+
if decorators:
|
|
235
|
+
for decorator_index, decorator_node in enumerate(decorators):
|
|
236
|
+
class_ast.decorator_list.insert(decorator_index, decorator_node)
|
|
237
|
+
ast.fix_missing_locations(class_ast)
|
|
94
238
|
|
|
95
239
|
@staticmethod
|
|
96
|
-
def
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
240
|
+
def _get_decorator(ast_root, origin_net):
|
|
241
|
+
"""Get the decorators of function"""
|
|
242
|
+
net_name = type(origin_net).__name__
|
|
243
|
+
decorators = []
|
|
244
|
+
|
|
245
|
+
class GetClassNode(ast.NodeVisitor):
|
|
246
|
+
"""Find the class node from input ast node."""
|
|
247
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
248
|
+
"""Visit the class node and add the decorators to class node"""
|
|
249
|
+
if node.name == net_name:
|
|
250
|
+
for decorator in node.decorator_list[:]:
|
|
251
|
+
decorator_name = ""
|
|
252
|
+
if isinstance(decorator, ast.Call):
|
|
253
|
+
func = decorator.func
|
|
254
|
+
if isinstance(func, ast.Name):
|
|
255
|
+
decorator_name = func.id
|
|
256
|
+
elif isinstance(decorator, ast.Name):
|
|
257
|
+
decorator_name = decorator.id
|
|
258
|
+
# User should set the denied class_decorator,
|
|
259
|
+
# because the symbol_tree cant pass the correct parameters to decorators but the instance "obj".
|
|
260
|
+
if decorator_name not in ModuleParser.denied_class_decorator_list:
|
|
261
|
+
decorators.append(decorator)
|
|
262
|
+
return node
|
|
263
|
+
|
|
264
|
+
def get_node(self, input_ast):
|
|
265
|
+
"""Interface of GetClassNode."""
|
|
266
|
+
self.generic_visit(input_ast)
|
|
267
|
+
return True
|
|
268
|
+
|
|
269
|
+
get_node_handler = GetClassNode()
|
|
270
|
+
get_node_handler.get_node(ast_root)
|
|
271
|
+
return decorators
|
|
101
272
|
|
|
102
273
|
def target(self):
|
|
103
274
|
"""Parse target type"""
|
|
104
275
|
return ast.Module
|
|
105
276
|
|
|
106
|
-
def process(self, stree: SymbolTree, node: ast.Module):
|
|
277
|
+
def process(self, stree: SymbolTree, node: ast.Module, node_manager: NodeManager):
|
|
107
278
|
"""Process ast.ClassDef nodes in ast.Module."""
|
|
108
|
-
ModuleParser.
|
|
279
|
+
ModuleParser._save_imports(stree)
|
|
109
280
|
class_ast = ModuleParser._find_class(node)
|
|
281
|
+
ModuleParser._add_decorator_to_class(class_ast, stree.get_origin_network())
|
|
110
282
|
stree.set_class_ast(class_ast)
|
|
111
|
-
ModuleParser._save_net_file_path(stree)
|
|
112
283
|
for body in node.body:
|
|
113
284
|
if isinstance(body, ast.ClassDef):
|
|
114
285
|
parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef)
|
|
115
|
-
parser.process(stree, body)
|
|
286
|
+
parser.process(stree, body, stree)
|
|
116
287
|
else:
|
|
117
288
|
logger.info(f"For MindSpore Rewrite, in module parser, Ignoring unsupported "
|
|
118
289
|
f"node({astunparse.unparse(body)}) in ast.Module.")
|
|
@@ -16,7 +16,8 @@
|
|
|
16
16
|
import abc
|
|
17
17
|
import ast
|
|
18
18
|
|
|
19
|
-
from
|
|
19
|
+
from ..symbol_tree import SymbolTree
|
|
20
|
+
from ..node.node_manager import NodeManager
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class Parser(abc.ABC):
|
|
@@ -34,12 +35,13 @@ class Parser(abc.ABC):
|
|
|
34
35
|
return type(None)
|
|
35
36
|
|
|
36
37
|
@abc.abstractmethod
|
|
37
|
-
def process(self, stree: SymbolTree, node: ast.AST):
|
|
38
|
+
def process(self, stree: SymbolTree, node: ast.AST, node_manager: NodeManager):
|
|
38
39
|
"""
|
|
39
40
|
Parse input ast node and add parse result into SymbolTree.
|
|
40
41
|
|
|
41
42
|
Args:
|
|
42
43
|
stree (SymbolTree): current symbol_tree
|
|
43
44
|
node (ast.AST): node who is tried to be parsed
|
|
45
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
44
46
|
"""
|
|
45
47
|
raise NotImplementedError
|
|
@@ -45,7 +45,7 @@ class ParserRegister:
|
|
|
45
45
|
parser (Parser): An instance of Parser to be registered.
|
|
46
46
|
"""
|
|
47
47
|
if isinstance(parser, Parser):
|
|
48
|
-
ParserRegister.instance().
|
|
48
|
+
ParserRegister.instance().get_parsers()[parser.target()] = parser
|
|
49
49
|
|
|
50
50
|
def get_parser(self, ast_type: type) -> Optional[Parser]:
|
|
51
51
|
"""
|
|
@@ -16,9 +16,10 @@
|
|
|
16
16
|
import ast
|
|
17
17
|
|
|
18
18
|
from ..symbol_tree import SymbolTree
|
|
19
|
-
from ..node import Node
|
|
20
|
-
from ..
|
|
21
|
-
from
|
|
19
|
+
from ..node.node import Node
|
|
20
|
+
from ..node.node_manager import NodeManager
|
|
21
|
+
from .parser import Parser
|
|
22
|
+
from .parser_register import reg_parser
|
|
22
23
|
from ..common import error_str
|
|
23
24
|
|
|
24
25
|
|
|
@@ -29,14 +30,13 @@ class ReturnParser(Parser):
|
|
|
29
30
|
"""Parse target type"""
|
|
30
31
|
return ast.Return
|
|
31
32
|
|
|
32
|
-
def process(self, stree: SymbolTree, node: ast.Return):
|
|
33
|
+
def process(self, stree: SymbolTree, node: ast.Return, node_manager: NodeManager):
|
|
33
34
|
"""Parse ast.Return to output-node of SymbolTree."""
|
|
34
35
|
return_value = node.value
|
|
35
36
|
if not isinstance(return_value, ast.Name):
|
|
36
37
|
raise RuntimeError(error_str(f"only support ast.Name as return value, but got ast type "
|
|
37
38
|
f"'{type(return_value).__name__}'", father_node=node, child_node=return_value))
|
|
38
39
|
node_return = Node.create_output_node(node, [return_value.id])
|
|
39
|
-
stree.append_origin_field(node_return)
|
|
40
|
-
|
|
40
|
+
stree.append_origin_field(node_return, node_manager)
|
|
41
41
|
|
|
42
42
|
g_return_parser = reg_parser(ReturnParser())
|
|
@@ -13,11 +13,11 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Sparsify transformer"""
|
|
16
|
+
import sys
|
|
16
17
|
import ast
|
|
17
18
|
import inspect
|
|
18
19
|
import textwrap
|
|
19
20
|
from collections import deque
|
|
20
|
-
import astunparse
|
|
21
21
|
|
|
22
22
|
from mindspore import ops, nn
|
|
23
23
|
from mindspore import log as logger
|
|
@@ -25,6 +25,10 @@ from mindspore.rewrite.parsers.assign_parser import AssignParser
|
|
|
25
25
|
from mindspore.rewrite.sparsify.utils import ArgType, SparseFunc, sparse_rules, get_sparse_func, builtin_ops, \
|
|
26
26
|
get_binop_name, get_sparse_method_outputs, arg_type_to_prefix_map, get_inputs_outputs
|
|
27
27
|
|
|
28
|
+
if sys.version_info >= (3, 9):
|
|
29
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
30
|
+
else:
|
|
31
|
+
import astunparse
|
|
28
32
|
|
|
29
33
|
OPS_MODULE = "mindspore.ops."
|
|
30
34
|
MAX_RECURSION_DEPTH = 10
|
|
@@ -61,8 +65,13 @@ def sparsify_helper(f, arg_types, user_defined_rules=None, sparse_name="", full_
|
|
|
61
65
|
|
|
62
66
|
if changed:
|
|
63
67
|
sparse_tree = list(x[0] for x in sparse_transformer.sparse_functiondef.values()) + sparse_tree
|
|
64
|
-
|
|
65
|
-
|
|
68
|
+
if sys.version_info >= (3, 9):
|
|
69
|
+
ast_module = ast.Module([ast.FunctionDef(
|
|
70
|
+
sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)],
|
|
71
|
+
type_ignores=[])
|
|
72
|
+
else:
|
|
73
|
+
ast_module = ast.Module([ast.FunctionDef(
|
|
74
|
+
sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)])
|
|
66
75
|
return ast_module, True, return_types
|
|
67
76
|
return tree, False, return_types
|
|
68
77
|
|