mindspore 2.1.0__cp37-cp37m-win_amd64.whl → 2.2.11__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -16,16 +16,18 @@
|
|
|
16
16
|
from typing import Optional, Union
|
|
17
17
|
import ast
|
|
18
18
|
import inspect
|
|
19
|
+
from types import FunctionType
|
|
19
20
|
|
|
20
21
|
from mindspore.nn import Cell
|
|
21
22
|
from mindspore.ops import Primitive
|
|
22
23
|
from mindspore import log as logger
|
|
23
|
-
from
|
|
24
|
-
from
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
from
|
|
28
|
-
from
|
|
24
|
+
from ... import _checkparam as Validator
|
|
25
|
+
from ..ast_helpers import AstModifier
|
|
26
|
+
from ..api.scoped_value import ScopedValue, ValueType
|
|
27
|
+
from ..api.node_type import NodeType
|
|
28
|
+
from ..namespace import is_subtree
|
|
29
|
+
from ..ast_helpers.ast_replacer import AstReplacer
|
|
30
|
+
from ..ast_creator_register import ast_creator_registry
|
|
29
31
|
|
|
30
32
|
PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
|
|
31
33
|
|
|
@@ -36,35 +38,33 @@ class Node:
|
|
|
36
38
|
invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
|
|
37
39
|
Node has different meaning in different type of node:
|
|
38
40
|
|
|
39
|
-
- CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
|
|
40
|
-
is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
|
|
41
|
-
corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
|
|
42
|
-
corresponding to func of call expression which means symbol of the cell-op.
|
|
41
|
+
- CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
|
|
42
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
|
|
43
|
+
`kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
|
|
44
|
+
method. `func` is corresponding to func of call expression which means symbol of the cell-op.
|
|
43
45
|
- CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
|
|
44
|
-
`targets`, `args`, `kwargs` and `
|
|
46
|
+
`targets`, `args`, `kwargs` and `func_name` are as previous.
|
|
45
47
|
- CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
|
|
46
|
-
`targets` is corresponding to targets of ast.Assign which means return values of this method. `
|
|
47
|
-
the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
|
|
48
|
-
value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
|
|
49
|
-
CallMethod node whose `
|
|
50
|
-
- GetAttr: retrieves a parameter from the SymbolTree hierarchy. `func` represents which parameter in SymbolTree
|
|
51
|
-
hierarchy. `targets` is corresponding to targets of ast.Assign which means what symbol to accept the result of
|
|
52
|
-
get-attr. `args` and `kwargs` are don't-care.
|
|
48
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
|
|
49
|
+
represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
|
|
50
|
+
method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
|
|
51
|
+
mapped to CallMethod node whose `func_name` is "PassThrough".
|
|
53
52
|
- Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
|
|
54
|
-
supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `
|
|
53
|
+
supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
|
|
55
54
|
- Input: an input node represents an input of current network which also a parameter of forward method of Cell.
|
|
56
55
|
`targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
|
|
57
|
-
of forward function. `kwargs` and `
|
|
56
|
+
of forward function. `kwargs` and `func_name` are don't-care.
|
|
58
57
|
- Output: an output node represents the output of current network which is corresponding to return statement of
|
|
59
|
-
forward method of Cell. `args` represents return values. `
|
|
60
|
-
don't-care.
|
|
58
|
+
forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
|
|
59
|
+
`kwargs` are don't-care.
|
|
61
60
|
- Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
|
|
62
|
-
`targets`, `args`, `kwargs` and `
|
|
63
|
-
instance.
|
|
61
|
+
`targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
|
|
62
|
+
SymbolTree instance.
|
|
64
63
|
"""
|
|
65
64
|
|
|
66
65
|
def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
67
|
-
|
|
66
|
+
func_name: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str,
|
|
67
|
+
instance):
|
|
68
68
|
"""
|
|
69
69
|
Constructor of Node. Rewrite recommends invoking class method of Node to instantiate an instance of Node such
|
|
70
70
|
as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
|
|
@@ -75,7 +75,7 @@ class Node:
|
|
|
75
75
|
ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
|
|
76
76
|
not be None except when node type is Unknown.
|
|
77
77
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
78
|
-
|
|
78
|
+
func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
|
|
79
79
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
80
80
|
kwargs (dict{str: ScopedValue}): A dict whose values are instances of ScopedValue. See detail in docstring of Node class.
|
|
81
81
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
@@ -89,7 +89,7 @@ class Node:
|
|
|
89
89
|
self._attribute = Node._get_cell_or_prim_op_attribute(instance)
|
|
90
90
|
self._instance = instance
|
|
91
91
|
self._name = name
|
|
92
|
-
self.
|
|
92
|
+
self._func_name: Optional[ScopedValue] = func_name
|
|
93
93
|
self._targets: [ScopedValue] = targets
|
|
94
94
|
self._args_num = len(args) if args is not None else 0
|
|
95
95
|
self._kwargs_num = len(kwargs) if kwargs is not None else 0
|
|
@@ -101,48 +101,17 @@ class Node:
|
|
|
101
101
|
self._next: Optional[Node] = None
|
|
102
102
|
# A handler of SymbolTree current node belonging to
|
|
103
103
|
self._belong_tree = None
|
|
104
|
-
# A
|
|
104
|
+
# A handler of NodeManager current node belonging to
|
|
105
|
+
self._node_manager = None
|
|
106
|
+
# A dict that records which target of which Node current Node's argument come from
|
|
105
107
|
self._arg_providers: {int: (Node, int)} = {}
|
|
106
108
|
# A dict that records which argument of which Node uses current Node's target
|
|
107
109
|
self._target_users: {int: [(Node, int)]} = {}
|
|
108
110
|
|
|
109
|
-
@classmethod
|
|
110
|
-
def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
111
|
-
func: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
|
112
|
-
name: str = ""):
|
|
113
|
-
"""
|
|
114
|
-
Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
115
|
-
A `CallCell` node represents an invoking to cell-op.
|
|
116
|
-
A `CallPrimitive` node represents an invoking to primitive-op.
|
|
117
|
-
|
|
118
|
-
Args:
|
|
119
|
-
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
120
|
-
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
121
|
-
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
122
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
123
|
-
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
124
|
-
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
125
|
-
class.
|
|
126
|
-
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
127
|
-
Name of node also used as field name in network class.
|
|
128
|
-
"""
|
|
129
|
-
|
|
130
|
-
if not isinstance(op, (Cell, Primitive)):
|
|
131
|
-
raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
|
|
132
|
-
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
133
|
-
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
134
|
-
if ast_node is None:
|
|
135
|
-
ast_node = AstModifier.create_call_assign(targets, func, non_custom_args, non_custom_kwargs)
|
|
136
|
-
if isinstance(op, Cell):
|
|
137
|
-
node_type = NodeType.CallCell
|
|
138
|
-
else:
|
|
139
|
-
node_type = NodeType.CallPrimitive
|
|
140
|
-
return cls(node_type, ast_node, targets, func, args, kwargs, name, op)
|
|
141
|
-
|
|
142
111
|
@classmethod
|
|
143
112
|
def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
144
|
-
|
|
145
|
-
name: str = ""):
|
|
113
|
+
func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
|
|
114
|
+
kwargs: {str: ScopedValue}=None, name: str = ""):
|
|
146
115
|
"""
|
|
147
116
|
Class method of Node. Instantiate an instance of node whose type is CallMethod. A CallMethod node represents an
|
|
148
117
|
invoking to a method.
|
|
@@ -151,7 +120,7 @@ class Node:
|
|
|
151
120
|
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
|
|
152
121
|
should not be None currently.
|
|
153
122
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
154
|
-
|
|
123
|
+
func_name (Union[ScopedValue, str]): An instance of ScopedValue, or a string which is converted to a naming ScopedValue. See detail in docstring of Node class.
|
|
155
124
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
156
125
|
kwargs (dict{str: ScopedValue}): A dict whose values are instances of ScopedValue. See detail in docstring of Node class.
|
|
157
126
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
@@ -161,12 +130,12 @@ class Node:
|
|
|
161
130
|
args = []
|
|
162
131
|
if kwargs is None:
|
|
163
132
|
kwargs = {}
|
|
164
|
-
if isinstance(
|
|
165
|
-
|
|
133
|
+
if isinstance(func_name, str):
|
|
134
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
166
135
|
new_targets = Node._handle_targets(targets)
|
|
167
136
|
if ast_node is None:
|
|
168
137
|
raise RuntimeError("Input ast_node is None")
|
|
169
|
-
return cls(NodeType.CallMethod, ast_node, new_targets,
|
|
138
|
+
return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
|
|
170
139
|
|
|
171
140
|
@classmethod
|
|
172
141
|
def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
@@ -189,7 +158,8 @@ class Node:
|
|
|
189
158
|
return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance)
|
|
190
159
|
|
|
191
160
|
@classmethod
|
|
192
|
-
def create_input_node(cls, ast_node: ast.AST, arg_name: str, default: Optional[ScopedValue] = None,
|
|
161
|
+
def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
|
|
162
|
+
name: str = ""):
|
|
193
163
|
"""
|
|
194
164
|
Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
|
|
195
165
|
SymbolTree which is corresponding to parameters of forward function.
|
|
@@ -206,6 +176,8 @@ class Node:
|
|
|
206
176
|
args = []
|
|
207
177
|
else:
|
|
208
178
|
args = [default]
|
|
179
|
+
if ast_node is None:
|
|
180
|
+
ast_node = ast.arg(arg_name)
|
|
209
181
|
return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
|
|
210
182
|
|
|
211
183
|
@classmethod
|
|
@@ -243,17 +215,83 @@ class Node:
|
|
|
243
215
|
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
|
244
216
|
sequentially in the list.
|
|
245
217
|
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
|
246
|
-
|
|
218
|
+
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
|
247
219
|
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
248
220
|
Name of node also used as field name in network class. The format of mathops node name
|
|
249
221
|
is 'AstNodeName_AstOpName_n'.
|
|
250
222
|
"""
|
|
251
223
|
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
|
|
252
224
|
|
|
225
|
+
@staticmethod
|
|
226
|
+
def create_assign_node(targets, func_name, args, kwargs):
|
|
227
|
+
"""Create an ast.Assign type node."""
|
|
228
|
+
# create targets
|
|
229
|
+
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
230
|
+
# create call
|
|
231
|
+
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
232
|
+
ast_args = ast_creator_registry.get("Args")(args)
|
|
233
|
+
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
234
|
+
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
235
|
+
# create assign
|
|
236
|
+
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
237
|
+
return ast_node
|
|
238
|
+
|
|
239
|
+
@staticmethod
|
|
240
|
+
def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
|
241
|
+
kwargs: {str: ScopedValue}=None):
|
|
242
|
+
"""
|
|
243
|
+
Create a node that corresponds to a function call.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
function (FunctionType): The function to be called.
|
|
247
|
+
targets (list[str]): indicates output names. Used as targets of an assign statement in source code.
|
|
248
|
+
args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
249
|
+
source code. Default: ``None`` , which indicates the `function` has no args inputs.
|
|
250
|
+
kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
251
|
+
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
252
|
+
code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
An instance of `Node`.
|
|
256
|
+
"""
|
|
257
|
+
if args is None:
|
|
258
|
+
args = []
|
|
259
|
+
if kwargs is None:
|
|
260
|
+
kwargs = {}
|
|
261
|
+
targets = Node._handle_targets(targets)
|
|
262
|
+
_package = None
|
|
263
|
+
if isinstance(function, FunctionType):
|
|
264
|
+
_package = function.__globals__['__package__']
|
|
265
|
+
func_full_name = ".".join([_package, function.__name__]) if _package else function.__name__
|
|
266
|
+
func_scope = ''
|
|
267
|
+
func_name = func_full_name.split('.')[-1]
|
|
268
|
+
if func_full_name.count('.') > 0:
|
|
269
|
+
func_scope = func_full_name.rsplit('.')[0]
|
|
270
|
+
func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
|
|
271
|
+
node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
|
|
272
|
+
return node
|
|
273
|
+
|
|
274
|
+
@classmethod
|
|
275
|
+
def inner_create_call_function(cls, node_name, ast_node, func_name, function, targets, args, kwargs):
|
|
276
|
+
'''
|
|
277
|
+
Instantiate an instance of node whose type is `CallFunction`.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
node_name (str): Name of node.
|
|
281
|
+
func_name (str): Name of function.
|
|
282
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
283
|
+
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
284
|
+
function (Object): An instance of function. See detail in docstring of Node class.
|
|
285
|
+
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
286
|
+
kwargs (dict{str: ScopedValue}): A dict whose values are instances of `ScopedValue`. See detail in docstring of `Node`
|
|
287
|
+
class.
|
|
288
|
+
'''
|
|
289
|
+
return cls(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, function)
|
|
290
|
+
|
|
253
291
|
@staticmethod
|
|
254
292
|
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
255
|
-
|
|
256
|
-
|
|
293
|
+
args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
|
|
294
|
+
is_sub_net: bool = False):
|
|
257
295
|
"""
|
|
258
296
|
Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
259
297
|
If op is custom defined, it is treated by TreeNode.
|
|
@@ -264,12 +302,11 @@ class Node:
|
|
|
264
302
|
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
265
303
|
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
266
304
|
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
267
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
268
305
|
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
269
306
|
kwargs (dict{str: ScopedValue}): A dict whose values are instances of `ScopedValue`. See detail in docstring of `Node`
|
|
270
307
|
class.
|
|
271
|
-
|
|
272
|
-
Name of node also used as field name in network class.
|
|
308
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
309
|
+
`SymbolTree`. Name of node also used as field name in network class.
|
|
273
310
|
is_sub_net (bool): Indicate whether `op` is a network. If `is_sub_net` is true, Rewrite will try to parse the
|
|
274
311
|
`op` to a TreeNode, else a CallCell Node. Default: ``False``.
|
|
275
312
|
"""
|
|
@@ -277,29 +314,58 @@ class Node:
|
|
|
277
314
|
if ast_node is not None:
|
|
278
315
|
Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node")
|
|
279
316
|
Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
|
|
280
|
-
Validator.check_value_type("func", func, [ScopedValue, str], "Node")
|
|
281
317
|
if args is not None:
|
|
282
318
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
283
319
|
if kwargs is not None:
|
|
284
320
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
285
|
-
cls_name = type(op).__name__
|
|
286
|
-
|
|
287
321
|
if args is None:
|
|
288
322
|
args = []
|
|
289
323
|
if kwargs is None:
|
|
290
324
|
kwargs = {}
|
|
291
|
-
|
|
292
|
-
func = ScopedValue.create_naming_value(func)
|
|
325
|
+
Validator.check_value_type("node_name", node_name, [str], "Node")
|
|
293
326
|
new_targets = Node._handle_targets(targets)
|
|
294
|
-
if
|
|
295
|
-
|
|
327
|
+
if isinstance(node_name, str):
|
|
328
|
+
func_name = ScopedValue.create_naming_value(node_name)
|
|
329
|
+
else:
|
|
330
|
+
func_name = node_name
|
|
331
|
+
if is_sub_net and is_subtree(op):
|
|
332
|
+
from ..symbol_tree_builder import SymbolTreeBuilder
|
|
296
333
|
stb = SymbolTreeBuilder(op)
|
|
297
334
|
stree = stb.build()
|
|
298
335
|
replacer = AstReplacer(stree.get_class_ast())
|
|
299
336
|
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
|
300
|
-
return TreeNode.create_tree_node(stree, ast_node, new_targets,
|
|
337
|
+
return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op)
|
|
301
338
|
|
|
302
|
-
return Node.create_call_buildin_op(op, ast_node, new_targets,
|
|
339
|
+
return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name)
|
|
340
|
+
|
|
341
|
+
@classmethod
|
|
342
|
+
def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
|
|
343
|
+
func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
|
344
|
+
node_name: str = ""):
|
|
345
|
+
"""
|
|
346
|
+
Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
|
347
|
+
A `CallCell` node represents an invoking to cell-op.
|
|
348
|
+
A `CallPrimitive` node represents an invoking to primitive-op.
|
|
349
|
+
|
|
350
|
+
Args:
|
|
351
|
+
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
|
352
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
353
|
+
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
354
|
+
func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
355
|
+
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
356
|
+
kwargs (dict{str: ScopedValue}): A dict whose values are instances of `ScopedValue`. See detail in docstring of `Node`
|
|
357
|
+
class.
|
|
358
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
359
|
+
`SymbolTree`. Name of node also used as field name in network class.
|
|
360
|
+
"""
|
|
361
|
+
|
|
362
|
+
if not isinstance(op, (Cell, Primitive)):
|
|
363
|
+
raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
|
|
364
|
+
if isinstance(op, Cell):
|
|
365
|
+
node_type = NodeType.CallCell
|
|
366
|
+
else:
|
|
367
|
+
node_type = NodeType.CallPrimitive
|
|
368
|
+
return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
|
|
303
369
|
|
|
304
370
|
@staticmethod
|
|
305
371
|
def _get_construct_arg_names(parameters):
|
|
@@ -508,21 +574,23 @@ class Node:
|
|
|
508
574
|
"""
|
|
509
575
|
return self._next
|
|
510
576
|
|
|
511
|
-
def
|
|
577
|
+
def set_prev(self, node: 'Node'):
|
|
512
578
|
"""
|
|
513
|
-
|
|
579
|
+
Set previous node of current node.
|
|
514
580
|
|
|
515
581
|
Args:
|
|
516
|
-
node (
|
|
582
|
+
node (Node): Node to be set as previous node of current node.
|
|
583
|
+
"""
|
|
584
|
+
self._prev = node
|
|
517
585
|
|
|
518
|
-
|
|
519
|
-
|
|
586
|
+
def set_next(self, node: 'Node'):
|
|
587
|
+
"""
|
|
588
|
+
Set next node of current node.
|
|
589
|
+
|
|
590
|
+
Args:
|
|
591
|
+
node (Node): Node to be set as next node of current node.
|
|
520
592
|
"""
|
|
521
|
-
|
|
522
|
-
return self.has_same_ast(node._ast_node)
|
|
523
|
-
if isinstance(node, ast.AST):
|
|
524
|
-
return id(self._ast_node) == id(node)
|
|
525
|
-
return False
|
|
593
|
+
self._next = node
|
|
526
594
|
|
|
527
595
|
def get_ast(self) -> Optional[ast.AST]:
|
|
528
596
|
"""
|
|
@@ -552,16 +620,24 @@ class Node:
|
|
|
552
620
|
"""Set the symbol tree to which node belongs."""
|
|
553
621
|
self._belong_tree = symbol_tree
|
|
554
622
|
|
|
623
|
+
def get_node_manager(self):
|
|
624
|
+
"""Get the NodeManager current node belongs to."""
|
|
625
|
+
return self._node_manager
|
|
626
|
+
|
|
627
|
+
def set_node_manager(self, node_manager):
|
|
628
|
+
"""Set the NodeManager current node belongs to."""
|
|
629
|
+
self._node_manager = node_manager
|
|
630
|
+
|
|
555
631
|
def isolate(self):
|
|
556
632
|
"""Link prev node to next node and isolate node from source code order list."""
|
|
557
|
-
origin_prev: Optional[Node] = self.
|
|
558
|
-
origin_next: Optional[Node] = self.
|
|
633
|
+
origin_prev: Optional[Node] = self.get_prev()
|
|
634
|
+
origin_next: Optional[Node] = self.get_next()
|
|
559
635
|
if origin_prev is not None:
|
|
560
|
-
origin_prev.
|
|
636
|
+
origin_prev.set_next(origin_next)
|
|
561
637
|
if origin_next is not None:
|
|
562
|
-
origin_next.
|
|
563
|
-
self.
|
|
564
|
-
self.
|
|
638
|
+
origin_next.set_prev(origin_prev)
|
|
639
|
+
self.set_prev(None)
|
|
640
|
+
self.set_next(None)
|
|
565
641
|
|
|
566
642
|
def insert_before(self, node: 'Node'):
|
|
567
643
|
"""
|
|
@@ -571,12 +647,12 @@ class Node:
|
|
|
571
647
|
node (Node): An instance of node to be inserted in.
|
|
572
648
|
"""
|
|
573
649
|
node.isolate()
|
|
574
|
-
origin_prev: Optional[Node] = self.
|
|
650
|
+
origin_prev: Optional[Node] = self.get_prev()
|
|
575
651
|
if origin_prev is not None:
|
|
576
|
-
origin_prev.
|
|
577
|
-
node.
|
|
578
|
-
node.
|
|
579
|
-
self.
|
|
652
|
+
origin_prev.set_next(node)
|
|
653
|
+
node.set_prev(origin_prev)
|
|
654
|
+
node.set_next(self)
|
|
655
|
+
self.set_prev(node)
|
|
580
656
|
|
|
581
657
|
def insert_after(self, node: 'Node'):
|
|
582
658
|
"""
|
|
@@ -586,12 +662,12 @@ class Node:
|
|
|
586
662
|
node (Node): An instance of node to be inserted in.
|
|
587
663
|
"""
|
|
588
664
|
node.isolate()
|
|
589
|
-
origin_next: Optional[Node] = self.
|
|
590
|
-
self.
|
|
591
|
-
node.
|
|
592
|
-
node.
|
|
665
|
+
origin_next: Optional[Node] = self.get_next()
|
|
666
|
+
self.set_next(node)
|
|
667
|
+
node.set_prev(self)
|
|
668
|
+
node.set_next(origin_next)
|
|
593
669
|
if origin_next is not None:
|
|
594
|
-
origin_next.
|
|
670
|
+
origin_next.set_prev(node)
|
|
595
671
|
|
|
596
672
|
def get_inputs(self) -> ['Node']:
|
|
597
673
|
"""
|
|
@@ -651,26 +727,26 @@ class Node:
|
|
|
651
727
|
NodeType.MathOps):
|
|
652
728
|
self._sync_assign_targets_to_ast()
|
|
653
729
|
|
|
654
|
-
def
|
|
730
|
+
def get_func_name(self) -> ScopedValue:
|
|
655
731
|
"""
|
|
656
|
-
Getter of `
|
|
732
|
+
Getter of `_func_name`. See detail in docstring of Node class for meaning of func.
|
|
657
733
|
|
|
658
734
|
Returns:
|
|
659
735
|
An instance of ScopedValue.
|
|
660
736
|
"""
|
|
661
|
-
return self.
|
|
737
|
+
return self._func_name
|
|
662
738
|
|
|
663
|
-
def
|
|
739
|
+
def set_func_name(self, func_name: ScopedValue):
|
|
664
740
|
"""
|
|
665
|
-
Setter of `
|
|
741
|
+
Setter of `_func_name`. See detail in docstring of Node class for meaning of func.
|
|
666
742
|
|
|
667
743
|
Note:
|
|
668
|
-
When `
|
|
744
|
+
When `_func_name` is updated, corresponding ast node would be updated also.
|
|
669
745
|
|
|
670
746
|
Args:
|
|
671
747
|
func_name (ScopedValue): An instance of ScopedValue as new func_name.
|
|
672
748
|
"""
|
|
673
|
-
self.
|
|
749
|
+
self._func_name = func_name
|
|
674
750
|
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
|
|
675
751
|
self._sync_assign_func_to_ast()
|
|
676
752
|
|
|
@@ -747,11 +823,11 @@ class Node:
|
|
|
747
823
|
Validator.check_value_type("node", node, [Node], "Node")
|
|
748
824
|
Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
|
|
749
825
|
if out_idx is None:
|
|
750
|
-
if len(node.
|
|
826
|
+
if len(node.get_targets()) != 1:
|
|
751
827
|
raise RuntimeError("node should has one output when out_idx is not provided")
|
|
752
828
|
out_idx = 0
|
|
753
|
-
Validator.check_int_range(out_idx, 0, len(node.
|
|
754
|
-
new_arg = node.
|
|
829
|
+
Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
|
|
830
|
+
new_arg = node.get_targets()[out_idx]
|
|
755
831
|
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
|
756
832
|
self._sync_arg()
|
|
757
833
|
|
|
@@ -943,18 +1019,36 @@ class Node:
|
|
|
943
1019
|
def get_arg_providers(self) -> dict:
|
|
944
1020
|
"""
|
|
945
1021
|
Getter of _arg_providers.
|
|
1022
|
+
|
|
1023
|
+
Return:
|
|
1024
|
+
dict, key is type of int indicating the index of args, and value is type of tuple, which includes
|
|
1025
|
+
the node and the index of node's targets who provides the argument.
|
|
946
1026
|
"""
|
|
947
1027
|
return self._arg_providers
|
|
948
1028
|
|
|
949
1029
|
def set_arg_providers(self, index: int, provider: tuple):
|
|
950
1030
|
"""
|
|
951
1031
|
Setter of _arg_providers.
|
|
1032
|
+
|
|
1033
|
+
Args:
|
|
1034
|
+
index (int): Indicating provider of which argument need to be set.
|
|
1035
|
+
provider (tuple): A tuple includes the node and the index of node's targets who provides the argument.
|
|
952
1036
|
"""
|
|
953
1037
|
self._arg_providers[index] = provider
|
|
954
1038
|
|
|
955
1039
|
def get_target_users(self, index=-1) -> Union[dict, list]:
|
|
956
1040
|
"""
|
|
957
1041
|
Getter of _target_users.
|
|
1042
|
+
|
|
1043
|
+
Args:
|
|
1044
|
+
index (int): Indicating users of which target need to be got. Default: -1, means all targets' users will
|
|
1045
|
+
be returned.
|
|
1046
|
+
|
|
1047
|
+
Return:
|
|
1048
|
+
Union[dict, list]. When index is not -1, a list of users of specified target will be returned.
|
|
1049
|
+
The type of elements in list is tuple, which includes the user node and the index of node's arguments
|
|
1050
|
+
who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the
|
|
1051
|
+
value is list of users of corresponding target.
|
|
958
1052
|
"""
|
|
959
1053
|
if index == -1:
|
|
960
1054
|
return self._target_users
|
|
@@ -965,11 +1059,23 @@ class Node:
|
|
|
965
1059
|
def append_target_users(self, index: int, provider: tuple):
|
|
966
1060
|
"""
|
|
967
1061
|
Setter of _target_users.
|
|
1062
|
+
|
|
1063
|
+
Args:
|
|
1064
|
+
index (int): Indicating users of which target need to be appended.
|
|
1065
|
+
provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
|
|
1066
|
+
|
|
968
1067
|
"""
|
|
969
1068
|
if index not in self._target_users.keys():
|
|
970
1069
|
self._target_users[index] = []
|
|
971
1070
|
self._target_users.get(index).append(provider)
|
|
972
1071
|
|
|
1072
|
+
def update_ast_node(self) -> ast.AST:
|
|
1073
|
+
"""Update node's ast_node by current targets, func_name, args and kwargs."""
|
|
1074
|
+
ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(),
|
|
1075
|
+
self.get_args(), self.get_kwargs())
|
|
1076
|
+
self.set_ast(ast_assign)
|
|
1077
|
+
return ast_assign
|
|
1078
|
+
|
|
973
1079
|
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
|
|
974
1080
|
"""
|
|
975
1081
|
Merge args and kwargs to normalized args.
|
|
@@ -1010,6 +1116,10 @@ class Node:
|
|
|
1010
1116
|
self._normalized_args_keys.append(arg_key)
|
|
1011
1117
|
return normalized_args
|
|
1012
1118
|
|
|
1119
|
+
##########################################################################################################
|
|
1120
|
+
# Synchronize rewrite node args to ast node
|
|
1121
|
+
##########################################################################################################
|
|
1122
|
+
|
|
1013
1123
|
def _sync_assign_func_to_ast(self):
|
|
1014
1124
|
"""Sync func of ast.Call of ast.Assign from self._func_name when NodeType is CallCell or CallPrimitive."""
|
|
1015
1125
|
if self._ast_node is None:
|
|
@@ -1021,20 +1131,21 @@ class Node:
|
|
|
1021
1131
|
if not isinstance(call_ast, ast.Call):
|
|
1022
1132
|
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
|
|
1023
1133
|
func_ast = call_ast.func
|
|
1024
|
-
if not self.
|
|
1134
|
+
if not self._func_name.value:
|
|
1025
1135
|
if isinstance(func_ast, ast.Name):
|
|
1026
|
-
func_ast.id = self.
|
|
1136
|
+
func_ast.id = self._func_name.value
|
|
1027
1137
|
else:
|
|
1028
|
-
call_ast.func = ast.Name(self.
|
|
1138
|
+
call_ast.func = ast.Name(self._func_name.value, ast.Store())
|
|
1029
1139
|
else:
|
|
1030
1140
|
if isinstance(func_ast, ast.Attribute):
|
|
1031
1141
|
func_value = func_ast.value
|
|
1032
1142
|
if not isinstance(func_value, ast.Name):
|
|
1033
1143
|
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
|
|
1034
|
-
func_value.id = self.
|
|
1035
|
-
func_ast.attr = self.
|
|
1144
|
+
func_value.id = self._func_name.scope
|
|
1145
|
+
func_ast.attr = self._func_name.value
|
|
1036
1146
|
else:
|
|
1037
|
-
call_ast.func = ast.Attribute(ast.Name(self.
|
|
1147
|
+
call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
|
|
1148
|
+
self._func_name.value, ast.Store())
|
|
1038
1149
|
ast.fix_missing_locations(assign_ast)
|
|
1039
1150
|
|
|
1040
1151
|
def _sync_assign_targets_to_ast(self):
|
|
@@ -1050,7 +1161,7 @@ class Node:
|
|
|
1050
1161
|
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
|
1051
1162
|
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
|
1052
1163
|
raise RuntimeError("self._targets should have targets_ast same length")
|
|
1053
|
-
for i in
|
|
1164
|
+
for i, _ in enumerate(self._targets):
|
|
1054
1165
|
target = self._targets[i]
|
|
1055
1166
|
target_ast = targets_ast[0]
|
|
1056
1167
|
if isinstance(target_ast, ast.Name):
|
|
@@ -1070,7 +1181,7 @@ class Node:
|
|
|
1070
1181
|
return
|
|
1071
1182
|
assign_ast = self._ast_node
|
|
1072
1183
|
if not isinstance(assign_ast, ast.Assign):
|
|
1073
|
-
raise TypeError("assign_ast should be ast.Assign, got:
|
|
1184
|
+
raise TypeError(f"assign_ast should be ast.Assign, got: {type(assign_ast)}")
|
|
1074
1185
|
assign_value = assign_ast.value
|
|
1075
1186
|
if not isinstance(assign_value, ast.Call):
|
|
1076
1187
|
return
|
|
@@ -1121,23 +1232,31 @@ class Node:
|
|
|
1121
1232
|
if len(self._normalized_args_keys) != 1:
|
|
1122
1233
|
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
|
1123
1234
|
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
|
1124
|
-
if arg.type
|
|
1125
|
-
raise RuntimeError("arg should be an
|
|
1235
|
+
if arg.type != ValueType.ConstantValue:
|
|
1236
|
+
raise RuntimeError("arg should be an ConstantValue")
|
|
1126
1237
|
if arg.scope != "":
|
|
1127
1238
|
raise RuntimeError("arg.scope should be empty")
|
|
1128
1239
|
assign_value.value = arg.value
|
|
1129
1240
|
|
|
1130
1241
|
def _sync_call_method_args_to_ast(self):
|
|
1131
|
-
"""
|
|
1242
|
+
"""
|
|
1243
|
+
Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
|
|
1244
|
+
|
|
1245
|
+
For node with type of CallMethod, the value of ast.Assign is one of:
|
|
1246
|
+
- ast.Tuple
|
|
1247
|
+
- ast.Name
|
|
1248
|
+
- ast.Attribute
|
|
1249
|
+
- ...
|
|
1250
|
+
"""
|
|
1132
1251
|
if self._ast_node is None:
|
|
1133
1252
|
return
|
|
1134
1253
|
assign_ast = self._ast_node
|
|
1135
1254
|
if not isinstance(assign_ast, ast.Assign):
|
|
1136
1255
|
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
|
1137
1256
|
assign_value = assign_ast.value
|
|
1138
|
-
if self.
|
|
1257
|
+
if self._func_name == PASS_THROUGH_METHOD:
|
|
1139
1258
|
self._sync_call_pass_through_method_args_to_ast(assign_value)
|
|
1140
|
-
elif self.
|
|
1259
|
+
elif self._func_name.value == "tuple":
|
|
1141
1260
|
tuple_ast: ast.Tuple = assign_value
|
|
1142
1261
|
if len(self._normalized_args_keys) != len(tuple_ast.elts):
|
|
1143
1262
|
raise RuntimeError("size of self._normalized_args_keys should be equal to size of elements of tuple")
|
|
@@ -1157,10 +1276,16 @@ class Node:
|
|
|
1157
1276
|
else:
|
|
1158
1277
|
raise RuntimeError("Only support constant or symbol in tuple now")
|
|
1159
1278
|
else:
|
|
1160
|
-
raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self.
|
|
1279
|
+
raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func_name.value)
|
|
1161
1280
|
|
|
1162
1281
|
def _sync_return_node_to_ast(self):
|
|
1163
|
-
"""
|
|
1282
|
+
"""
|
|
1283
|
+
Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
|
|
1284
|
+
|
|
1285
|
+
For node with type of Output, the value of ast.Return is one of:
|
|
1286
|
+
- ast.Name
|
|
1287
|
+
- ast.Tuple
|
|
1288
|
+
"""
|
|
1164
1289
|
if self._ast_node is None:
|
|
1165
1290
|
return
|
|
1166
1291
|
return_ast = self._ast_node
|
|
@@ -1222,7 +1347,7 @@ class Node:
|
|
|
1222
1347
|
|
|
1223
1348
|
def _sync_arg(self):
|
|
1224
1349
|
"""Sync _normalized_args to corresponding ast node when updated."""
|
|
1225
|
-
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree
|
|
1350
|
+
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
|
|
1226
1351
|
NodeType.CellContainer, NodeType.CallFunction):
|
|
1227
1352
|
self._sync_call_cell_args_to_ast()
|
|
1228
1353
|
elif self._node_type == NodeType.Output:
|
|
@@ -1233,15 +1358,18 @@ class Node:
|
|
|
1233
1358
|
self._sync_mathops_node_args_to_ast()
|
|
1234
1359
|
|
|
1235
1360
|
|
|
1361
|
+
##########################################################################################################
|
|
1362
|
+
# Child classes
|
|
1363
|
+
##########################################################################################################
|
|
1364
|
+
|
|
1236
1365
|
class TreeNode(Node):
|
|
1237
1366
|
"""Tree type Node who holds a handler of SymbolTree."""
|
|
1238
1367
|
|
|
1239
1368
|
def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1240
1369
|
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1241
1370
|
"""
|
|
1242
|
-
Constructor of
|
|
1243
|
-
as `
|
|
1244
|
-
`create_output_node`, etc. rather than invoking constructor of Node directly.
|
|
1371
|
+
Constructor of TreeNode. Rewrite recommends invoking a class method of Node to instantiate an instance of
|
|
1372
|
+
TreeNode such as `create_tree_node` rather than invoking constructor of Node directly.
|
|
1245
1373
|
|
|
1246
1374
|
Args:
|
|
1247
1375
|
tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
|
|
@@ -1260,8 +1388,9 @@ class TreeNode(Node):
|
|
|
1260
1388
|
self.symbol_tree = tree
|
|
1261
1389
|
|
|
1262
1390
|
@classmethod
|
|
1263
|
-
def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
|
|
1264
|
-
args: [ScopedValue], kwargs: {str: ScopedValue},
|
|
1391
|
+
def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
|
|
1392
|
+
func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
|
|
1393
|
+
name: str = "", instance=None):
|
|
1265
1394
|
"""
|
|
1266
1395
|
Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an invoking
|
|
1267
1396
|
to sub-network.
|
|
@@ -1270,104 +1399,14 @@ class TreeNode(Node):
|
|
|
1270
1399
|
tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
|
|
1271
1400
|
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1272
1401
|
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1273
|
-
|
|
1402
|
+
func_name (Union[ScopedValue, str]): An instance of ScopedValue, or a str that is converted to a naming ScopedValue. See detail in docstring of Node class.
|
|
1274
1403
|
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1275
1404
|
kwargs (dict{str: ScopedValue}): A dict mapping str keys to instances of ScopedValue. See detail in docstring of Node class.
|
|
1276
1405
|
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1277
1406
|
Name of node also used as field name in network class.
|
|
1278
1407
|
instance: Object in network corresponding to this node.
|
|
1279
1408
|
"""
|
|
1280
|
-
|
|
1281
|
-
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
1282
|
-
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
1283
1409
|
new_targets = Node._handle_targets(targets)
|
|
1284
|
-
if isinstance(
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
|
|
1288
|
-
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
class CellContainer(Node):
|
|
1292
|
-
""" Container for saving cell-objects node. """
|
|
1293
|
-
class _Visitor():
|
|
1294
|
-
""" A iterator of CellContainer nodes. """
|
|
1295
|
-
def __init__(self, cellcontainer):
|
|
1296
|
-
self._cellcontainer = cellcontainer
|
|
1297
|
-
|
|
1298
|
-
def __len__(self):
|
|
1299
|
-
""" Get the number of nodes. """
|
|
1300
|
-
return self._cellcontainer.node_count
|
|
1301
|
-
|
|
1302
|
-
def __iter__(self):
|
|
1303
|
-
"""Create an iterator over the CellContainer."""
|
|
1304
|
-
count = len(self._cellcontainer.node_list)
|
|
1305
|
-
i = 0
|
|
1306
|
-
while i < count:
|
|
1307
|
-
curr = self._cellcontainer.node_list[i]
|
|
1308
|
-
if curr.valid:
|
|
1309
|
-
yield curr
|
|
1310
|
-
i += 1
|
|
1311
|
-
|
|
1312
|
-
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1313
|
-
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1314
|
-
"""Constructor of CellContainer.
|
|
1315
|
-
|
|
1316
|
-
Args:
|
|
1317
|
-
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1318
|
-
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1319
|
-
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
1320
|
-
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1321
|
-
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1322
|
-
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1323
|
-
Name of node also used as field name in network class.
|
|
1324
|
-
instance: Object in network corresponding to this node.
|
|
1325
|
-
"""
|
|
1326
|
-
if isinstance(func, str):
|
|
1327
|
-
func = ScopedValue.create_naming_value(func)
|
|
1328
|
-
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
|
|
1329
|
-
self._node_list = list()
|
|
1330
|
-
self._node_count = 0
|
|
1331
|
-
|
|
1332
|
-
@property
|
|
1333
|
-
def node_count(self):
|
|
1334
|
-
"""Number of nodes."""
|
|
1335
|
-
return len(self._node_list)
|
|
1336
|
-
|
|
1337
|
-
@property
|
|
1338
|
-
def node_list(self):
|
|
1339
|
-
""" Get node list. """
|
|
1340
|
-
return self._node_list
|
|
1341
|
-
|
|
1342
|
-
def append(self, node):
|
|
1343
|
-
""" Append new node to node list. """
|
|
1344
|
-
setattr(node, "container", self)
|
|
1345
|
-
setattr(node, "valid", True)
|
|
1346
|
-
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1347
|
-
self._node_list.append(node)
|
|
1348
|
-
# when creating a cell_container, node instance is already in SequentialCell cell_list
|
|
1349
|
-
# so here we need to write a if judgement
|
|
1350
|
-
if node.get_instance() not in self.get_instance().cell_list:
|
|
1351
|
-
self.get_instance().append(node.get_instance())
|
|
1352
|
-
|
|
1353
|
-
def erase(self, node):
|
|
1354
|
-
"""Erase node form container."""
|
|
1355
|
-
index_node = self.node_list.index(node)
|
|
1356
|
-
index_instance = self.get_instance().cell_list.index(node.get_instance())
|
|
1357
|
-
if index_node != index_instance:
|
|
1358
|
-
raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
|
|
1359
|
-
setattr(node, "valid", False)
|
|
1360
|
-
del self.get_instance()[index_node]
|
|
1361
|
-
del self._node_list[index_node]
|
|
1362
|
-
|
|
1363
|
-
def insert(self, index, node):
|
|
1364
|
-
"""Insert node into container"""
|
|
1365
|
-
self.node_list.insert(index, node)
|
|
1366
|
-
setattr(node, "container", self)
|
|
1367
|
-
setattr(node, "valid", True)
|
|
1368
|
-
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1369
|
-
self.get_instance()._insert(index, node.get_instance())
|
|
1370
|
-
|
|
1371
|
-
def nodes(self):
|
|
1372
|
-
""" Return a iterator of node."""
|
|
1373
|
-
return self._Visitor(self)
|
|
1410
|
+
if isinstance(func_name, str):
|
|
1411
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
1412
|
+
return cls(tree, ast_node, new_targets, func_name, args, kwargs, name, instance)
|