mindspore-2.4.0-cp310-cp310-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the package's registry page for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
- mindspore/_c_expression.cpython-310-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1819 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
|
+
import stat
|
|
17
|
+
from typing import Optional, Union, Tuple, Any, Dict, List
|
|
18
|
+
import types
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import ast
|
|
22
|
+
import importlib.util
|
|
23
|
+
import time
|
|
24
|
+
import inspect
|
|
25
|
+
from textwrap import dedent
|
|
26
|
+
from collections import OrderedDict
|
|
27
|
+
|
|
28
|
+
from mindspore.nn import Cell
|
|
29
|
+
from mindspore import log as logger
|
|
30
|
+
from .symbol_tree_dumper import SymbolTreeDumper
|
|
31
|
+
from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager
|
|
32
|
+
from ..api.node_type import NodeType
|
|
33
|
+
from ..api.scoped_value import ScopedValue, ValueType
|
|
34
|
+
from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \
|
|
35
|
+
AstImportFinder
|
|
36
|
+
from ..common.namer import TargetNamer, NodeNamer, ClassNamer
|
|
37
|
+
from ..common.observer import Observer
|
|
38
|
+
from ..common.observable import Observable
|
|
39
|
+
from ..common.event import Event
|
|
40
|
+
|
|
41
|
+
if sys.version_info >= (3, 9):
|
|
42
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
43
|
+
else:
|
|
44
|
+
import astunparse
|
|
45
|
+
|
|
46
|
+
class Position:
    """
    Position indicates a source code position in one network.

    Rewrite recommends using the class method `create()` of Position rather than the constructor of Position.

    Args:
        symbol_tree (SymbolTree): A handler of the SymbolTree in which the position is located.
        node (Node): A handler of the Node around which the position is located.
        before_node (bool): Whether the position is before (True) or after (False) `node`.
    """

    def __init__(self, symbol_tree, node, before_node: bool):
        self.symbol_tree = symbol_tree
        self.node = node
        self.before_node = before_node

    @classmethod
    def create(cls, symbol_tree, node, before_node):
        """
        Create a position, or return None when `symbol_tree` or `node` is None.

        Args:
            symbol_tree: A handler of the SymbolTree in which the position is located.
            node: A handler of the Node around which the position is located.
            before_node (bool): Whether the position is before (True) or after (False) `node`.

        Returns:
            An instance of `cls` (Position, or the subclass `create` is called on), or None.
        """
        if symbol_tree is None or node is None:
            return None
        # Instantiate through `cls` instead of the hard-coded Position so subclasses get their own type back.
        return cls(symbol_tree, node, before_node)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class FieldFinder(AstFinder):
    """
    Check whether an attribute access of the form `self.<field>` exists within a specific scope.

    Args:
        scope (ast.AST): An instance of ast node used as the search scope.
    """

    def __init__(self, scope: ast.AST):
        super().__init__(scope)
        self._result = False
        self._field_name = ""

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Visit a node of type ast.Attribute, flag a hit on `self.<field_name>` and keep walking children."""
        owner = node.value
        accessed_on_self = isinstance(owner, ast.Name) and owner.id == "self"
        if accessed_on_self and node.attr == self._field_name:
            self._result = True
        return super(FieldFinder, self).generic_visit(node)

    def check(self, field) -> bool:
        """
        Check whether `self.<field>` is accessed anywhere in the scope.

        Args:
            field (str): The target field name.

        Returns:
            bool, True if at least one `self.<field>` access is found in the scope.
        """
        # reset state so the same finder can be reused for several fields
        self._result, self._field_name = False, field
        self.visit(self._scope)
        return self._result
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class SymbolTree(Observer, Observable, NodeManager):
|
|
122
|
+
"""
|
|
123
|
+
A symbol-tree usually corresponding to forward method of a network.
|
|
124
|
+
|
|
125
|
+
Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree rather than invoking constructor
|
|
126
|
+
of SymbolTree directly.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
origin_network (Cell): A handler to original network instance.
|
|
130
|
+
module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
|
|
131
|
+
"""
|
|
132
|
+
# whether parse CallFunction node inserted by user.
|
|
133
|
+
_unparse_inserted_function = True
|
|
134
|
+
|
|
135
|
+
def __init__(self, origin_network: Cell, module_ast: ast.Module):
    """
    Initialize an empty SymbolTree for `origin_network`.

    Args:
        origin_network (Cell): A handler to the original network instance.
        module_ast (ast.Module): An instance of ast.Module representing the module ast of the original network.
    """
    Observer.__init__(self)
    Observable.__init__(self)
    self._node_namer = NodeNamer()
    # presumably reserved because generated __init__ code refers to the name `obj`
    self._node_namer.add_name('obj')
    NodeManager.__init__(self)
    NodeManager.set_manager_node_namer(self, self._node_namer)
    NodeManager.reg_observer(self, observer=self)
    # init unique-namers
    self._target_namer = TargetNamer()
    # input arguments of function
    self._ori_cls_name = type(origin_network).__name__
    self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
    NodeManager.set_manager_name(self, self._opt_cls_name)
    self._origin_network = origin_network
    self._module_ast: ast.Module = module_ast
    # NOTE: the previous annotation `Optional[ast.Ast]` named a non-existent class and did not match the list value
    self._import_asts: List[ast.AST] = []
    self._class_ast: Optional[ast.ClassDef] = None
    self._root_ast: Optional[ast.FunctionDef] = None
    self._init_func_ast: Optional[ast.FunctionDef] = None
    self._deleted_field = {}
    self._deleted_node = []
    # {ast_function: [import_asts]}
    self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict()
    # {ast_class: [import_asts]}
    self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict()
    self._modified = False
    self._saved_file_name = "./network_define.py"
    # used to insert "sys.path.append(xxx)"
    self._net_file_paths = []
    self._tmp_import_strs = []
    # {network type: [SymbolTree]}; the previous annotation was a set literal, not a type
    self._tmp_unmodified_strees: Dict[type, List["SymbolTree"]] = {}
    self._tmp_replacers = []
    # user custom codes
    self._custom_codes: List[ast.AST] = []
    # local primitive instances initialized during forward method, e.g. abs_inst = P.Abs()
    self._local_prim_inits: List[Node] = []
|
|
172
|
+
|
|
173
|
+
@staticmethod
def _remove_unused_import(module_ast):
    """
    Remove unused imports from the top level of `module_ast` in place.

    Top-level import statements are collected in order. At each top-level class or function definition, every
    name imported by the collected statements for which `StrChecker(definition).check(name)` is False (i.e. the
    name is presumably unused by that definition) is dropped, and import statements left without any name are
    removed from `module_ast.body`. A divider statement (an ast.Expr holding the name '#') resets the collected
    import statements. Wildcard imports (`from x import *`) are always kept.

    Args:
        module_ast (ast.Module): The module ast to clean up.
    """
    import_nodes: List[Union[ast.Import, ast.ImportFrom]] = []

    def is_divider(ast_node):
        """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
        return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'

    for ast_node in module_ast.body[:]:
        if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
            import_nodes.append(ast_node)
        if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
            str_checker = StrChecker(ast_node)
            # iterate over a copy: emptied import statements are dropped from `import_nodes` below
            for import_node in import_nodes[:]:
                for alias in import_node.names[:]:
                    name = alias.asname if alias.asname else alias.name
                    if name == '*':
                        continue
                    if not str_checker.check(name):
                        import_node.names.remove(alias)
                if not import_node.names:
                    module_ast.body.remove(import_node)
                    # Stop tracking it; otherwise a following definition in the same import group would try to
                    # remove the same statement from module_ast.body again and raise ValueError.
                    import_nodes.remove(import_node)
        if is_divider(ast_node):
            import_nodes.clear()
|
|
198
|
+
|
|
199
|
+
@staticmethod
def _remove_duplicated_import(module_ast):
    """
    Remove duplicated top-level imports and definitions from `module_ast` in place.

    - An import (or a try-block whose body starts with an import) whose unparsed text was already seen is removed.
    - A class or function whose name was already defined earlier at top level is removed.
    - `from __future__ import ...` statements are removed.
    - Names imported via `from x import A` are dropped when `A` was already defined earlier in the module
      (e.g. a father class copied into the module); the statement is removed if no name is left.

    Try statements that do not guard an import are kept unchanged.

    Args:
        module_ast (ast.Module): The module ast to clean up.
    """
    imports = set()
    names = set()

    class TransImportNode(ast.NodeTransformer):
        """Drop duplicated top-level imports and definitions of the input ast node."""

        def visit_ClassDef(self, node: ast.ClassDef) -> Any:
            if node.name not in names:
                names.add(node.name)
                return node
            return None

        def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
            if node.name not in names:
                names.add(node.name)
                return node
            return None

        def visit_Try(self, node: ast.Try) -> Any:
            # Only try-blocks guarding an import are de-duplicated; returning None would delete any other
            # try statement, which is not an import at all.
            if not isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
                return node
            import_str = astunparse.unparse(node)
            if import_str not in imports:
                imports.add(import_str)
                return node
            return None

        def visit_Import(self, node: ast.Import) -> Any:
            import_str = astunparse.unparse(node)
            if import_str not in imports:
                imports.add(import_str)
                return node
            return None

        def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
            """
            Once the father class 'A' is defined in the current module, all the next imported class 'A' should
            be removed. e.g.
                class A():
                    ...
                from xxx import A, B
                =>
                class A():
                    ...
                from xxx import B
            """
            import_str = astunparse.unparse(node)

            if import_str not in imports:
                imports.add(import_str)
                # remove "__future__" module
                if node.module == '__future__':
                    return None
                # remove modules which have been defined in the code file
                # it occurs when class A is a father class and other sub-classes import A
                for alias in node.names[:]:
                    if alias.name in names:
                        node.names.remove(alias)
                # if the alias(es) in node.names are all removed, this import statement should be removed
                if not node.names:
                    return None
                return node
            return None

    get_node_handler = TransImportNode()
    get_node_handler.generic_visit(module_ast)
|
|
269
|
+
|
|
270
|
+
@staticmethod
def _remove_arg_annotations(module_ast):
    """
    Strip the annotation of every ast.arg found under `module_ast`, in place.

    Dropping annotations avoids 'xxx is not defined' errors when the rewritten code no longer imports the
    annotation types.
    """
    found_args: List[ast.arg] = AstFinder(module_ast).find_all(ast.arg)
    for arg_node in found_args:
        arg_node.annotation = None
|
|
276
|
+
|
|
277
|
+
@staticmethod
def _check_import(import_path: str, import_module: str):
    """
    Check whether `import_module` can be imported when `import_path` is on sys.path.

    Note that `import_path` is appended to sys.path if missing and is left there, and the module is really
    imported (its top-level code runs).

    Args:
        import_path (str): Directory to make importable.
        import_module (str): Dotted module name to try.

    Returns:
        bool, True when the import succeeded, False when it raised any exception.
    """
    if import_path not in sys.path:
        sys.path.append(import_path)
    try:
        importlib.import_module(name=import_module)
    except Exception as e:  # pylint: disable=W0703
        # ValueError / ImportError and any error raised by the module body are all treated as failure
        logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
        return False
    return True
|
|
293
|
+
|
|
294
|
+
@staticmethod
def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str):
    """
    Turn a relative `from ... import ...` into an absolute one, resolved against `file_path`.

    e.g. `from ..C import xxx` in A/B/x.py -> `from A.C import xxx` (depending on sys.path).

    Args:
        import_node (Union[ast.Import, ast.ImportFrom]): The import statement to process.
        file_path (str): Path of the file that contains `import_node`.

    Returns:
        A new ast.ImportFrom with level 0 when the module could be resolved, otherwise `import_node` itself.
    """
    if not isinstance(import_node, ast.ImportFrom):
        return import_node
    normalized_path = os.path.normpath(os.path.normcase(file_path))
    resolved_module = SymbolTree._get_valid_import_info(import_node, normalized_path)
    if not resolved_module:
        return import_node
    return ast.ImportFrom(module=resolved_module, names=import_node.names, level=0)
|
|
306
|
+
|
|
307
|
+
@staticmethod
def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str):
    """
    Resolve the absolute module name of a (possibly relative) `from ... import ...` statement.

    For a relative import, the directory referred to by `import_node.level` is derived from `file_path`
    ('.' is the directory of the file, each extra '.' goes one level up). Then each ancestor directory of it that
    is on sys.path is tried as import root, nearest first, until a test import of the padded module name succeeds.

    Args:
        import_node (ast.ImportFrom): The import statement to resolve.
        file_path (str): Path of the source file containing `import_node`.

    Returns:
        str or None. `import_node.module` unchanged for an absolute import, the padded module name for a resolved
        relative import, or None if every candidate root failed.
    """
    level = import_node.level
    # from A import xxx: it does not need to pad, directly return the module name
    if level == 0:
        return import_node.module
    # from .(A) import xxx: current path; from ..(A) import xxx: last level path
    package_dir = os.path.dirname(os.path.realpath(file_path))
    for _ in range(level - 1):
        package_dir = os.path.dirname(package_dir)
    # suffix is the module_name, e.g. '.A' in 'from ..(A) import xxx'
    suffix = '.' + import_node.module if import_node.module else ''
    # sys.path entries may be '' (current directory), relative, symlinked or differ in case; normalize them the
    # same way as `package_dir` so that the membership test below is meaningful.
    search_roots = {os.path.normcase(os.path.realpath(entry or os.curdir))
                    for entry in sys.path if isinstance(entry, str)}
    candidate_root = package_dir
    # bounded so that the filesystem root itself is never used as candidate
    max_level_count = package_dir.count(os.path.sep) - 1
    for _ in range(max_level_count):
        candidate_root = os.path.dirname(candidate_root)
        if os.path.normcase(candidate_root) not in search_roots:
            logger.debug(f"{candidate_root} not in sys.path, try upper level.")
            continue
        import_module = package_dir[len(candidate_root) + 1:].replace(os.path.sep, '.') + suffix
        if SymbolTree._check_import(candidate_root, import_module):
            # try test code success
            return import_module
        # test import ast failed, try upper level
        logger.info("Try upper level.")
    # try codes with all level failed
    logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.")
    return None
|
|
344
|
+
|
|
345
|
+
@staticmethod
def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager):
    """
    Update the ast when a NodeType.Input node is inserted.

    `new_node` is appended to the input nodes of `node_manager`, and an argument named after the first target of
    `new_node` is appended to the function ast managed by `node_manager`.

    Raises:
        ValueError: If `node_manager` is neither a SymbolTree nor a CallFunction.
    """
    if isinstance(node_manager, (SymbolTree, CallFunction)):
        node_manager.get_input_nodes().append(new_node)
        function_ast: ast.FunctionDef = node_manager.get_manager_ast()
        arg_name: str = new_node.get_targets()[0].value
        new_arg = ast.arg(arg=arg_name, annotation=None, type_comment=None)
        AstModifier.append_arg_to_function(function_ast, new_arg)
        return
    raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of "
                     f"CallFunction, but get {type(node_manager)}")
|
|
357
|
+
|
|
358
|
+
@staticmethod
def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool,
                                              node_manager: NodeManager, stree):
    """
    Update the ast when a NodeType.CallCell, NodeType.CallPrimitive or NodeType.Tree node is inserted.

    Steps:
    1. Make sure `new_node` has an ast.Assign; if it has none yet, give it a unique `self.<name>` function name
       and build one.
    2. Store the node instance on the original network under the node's name.
    3. Add an assignment of that instance to `stree`'s __init__ function ast.
    4. Insert the ast.Assign into the ast managed by `node_manager`, before/after `base_node`.

    Raises:
        ValueError: If the ast of `new_node` is not an ast.Assign.
        RuntimeError: If the ast of `node_manager` is None.
    """
    ast_assign = new_node.get_ast()
    if ast_assign is None:
        unique_name = stree.unique_func_name(new_node.get_name())
        new_node.set_func_name(ScopedValue.create_naming_value(unique_name, "self"))
        ast_assign = new_node.update_ast_node()
    if not isinstance(ast_assign, ast.Assign):
        raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
    node_name = new_node.get_name()
    # Save instance into _origin_network; the generated __init__ code below reads it back as `obj.<name>`
    # (presumably `obj` refers to the original network - confirm).
    setattr(stree.get_origin_network(), node_name, new_node.get_instance())
    if isinstance(new_node, TreeNode):
        # a sub symbol tree: wrap the original instance with the rewritten class of that tree
        wrapped = f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{node_name})"
    else:
        wrapped = f"obj.{node_name}"
    init_statement = ast.parse(f"{new_node.get_func_name()} = {wrapped}").body[0]
    AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_statement)
    # Insert ast to construct_function/class_internal_function
    anchor_ast = base_node.get_ast() if base_node else None
    manager_ast = node_manager.get_manager_ast()
    if not manager_ast:
        raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
                           "when inserting the ast.")
    AstModifier.insert_ast_to_ast(manager_ast, ast_assign, anchor_ast, before_node)
|
|
387
|
+
|
|
388
|
+
@staticmethod
def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool,
                                        node_manager: NodeManager, stree: 'SymbolTree'):
    """
    Update the ast when a NodeType.CallFunction node is inserted.

    The call statement of `new_node` is inserted into the ast of `node_manager`, before/after `base_node`. Then
    the called function is made reachable from the generated code:

    - Python builtin functions are ignored.
    - When `SymbolTree._unparse_inserted_function` is True, or when the source of the function cannot be
      retrieved or is not an ast.FunctionDef, an import of the function is added to `stree`.
    - Otherwise the source of the function is flattened and parsed into nodes managed by `new_node`.

    Args:
        new_node (CallFunction): The inserted node.
        base_node (Node): The insertion anchor, may be None.
        before_node (bool): Whether to insert before (True) or after (False) `base_node`.
        node_manager (NodeManager): The manager owning `new_node`.
        stree (SymbolTree): The symbol tree `new_node` belongs to.

    Raises:
        RuntimeError: If the ast of `node_manager` is None.
    """
    func_name = str(new_node.get_func_name())
    # create a new assign statement
    ast_assign = new_node.get_ast()
    if ast_assign is None:
        ast_assign = new_node.update_ast_node()
    # Insert ast to node_manager
    ast_base_node = base_node.get_ast() if base_node else None
    ast_node_manager = node_manager.get_manager_ast()
    if not ast_node_manager:
        raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
                           "when inserting the ast.")
    AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
    # Ignore Python builtin functions
    func_obj = new_node.get_instance()
    if isinstance(func_obj, types.BuiltinFunctionType):
        logger.warning(f"Ignore built in function: {func_name}")
        return
    # Only read and parse the source when it is going to be expanded. inspect.getsource raises OSError when the
    # source is unavailable and TypeError for built-in objects, and a one-line lambda inside a larger
    # expression can yield unparsable source; all of these fall back to importing the function.
    ast_functiondef = None
    if not SymbolTree._unparse_inserted_function:
        try:
            parsed_ast = ast.parse(dedent(inspect.getsource(func_obj))).body[0]
        except (OSError, TypeError, SyntaxError) as e:
            logger.debug(f"Cannot parse source code of function '{func_name}': {e}")
        else:
            if isinstance(parsed_ast, ast.FunctionDef):
                ast_functiondef = parsed_ast
    if ast_functiondef is None:
        logger.debug(f"import '{func_name}' to access function object")
        # add import to make sure that the function object can be accessed.
        module = inspect.getmodule(func_obj)
        top_node_manager = node_manager.get_top_manager()
        belonging_ast = None if isinstance(top_node_manager, SymbolTree) else top_node_manager.get_manager_ast()
        stree.add_import(module, func_name, belonging_ast)
        return
    # parse nodes in inserted function.
    new_node.set_manager_ast(ast_functiondef)
    new_node.set_manager_node_namer(stree.get_node_namer())
    stree.get_external_ast()[ast_functiondef] = []
    # import module which function defined in
    func_file_path = inspect.getabsfile(func_obj)
    stree.save_imports_from_file(func_file_path, ast_functiondef)
    # expand ast codes in function
    from ..ast_helpers import AstFlattener
    ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name], stree)
    # parse ast codes into CallFunction Node
    from ..parsers import ParserRegister
    parser = ParserRegister.instance().get_parser(ast.FunctionDef)
    parser.process(stree, ast_functiondef, node_manager=new_node)
|
|
434
|
+
|
|
435
|
+
@staticmethod
def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool):
    """
    Update the ast after `new_node` has been inserted, dispatching on the node type.

    Args:
        new_node (Node): The inserted node.
        base_node (Node): The insertion anchor, may be None.
        before_node (bool): Whether `new_node` is placed before (True) or after (False) `base_node`.

    Raises:
        ValueError: If `new_node` has no belonging symbol tree, if its node manager is not a SymbolTree,
            CallFunction or ControlFlow, or if its node type is not supported.
    """
    stree = new_node.get_belong_symbol_tree()
    if not stree:
        raise ValueError("When inserting node to ast, the belonging symbol tree of new_node is None.")
    node_manager = new_node.get_node_manager()
    if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)):
        raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can "
                         f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}")
    node_type = new_node.get_node_type()
    if node_type == NodeType.Input:
        SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager)
        return
    if node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
        SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager,
                                                             stree)
        return
    if node_type == NodeType.CallFunction:
        SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree)
        return
    raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be "
                     f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got "
                     f"{node_type}.")
|
|
456
|
+
|
|
457
|
+
@staticmethod
def get_node_full_name(node: Node) -> str:
    """
    Get the dotted full name of `node`, prefixed with the names of all of its enclosing node managers.

    Returns:
        str, e.g. "<symbol tree name>.<manager name>...<node name>".
    """
    # collected innermost-first, joined outermost-first at the end
    parts = [node.get_manager_name() if isinstance(node, NodeManager) else node.get_name()]
    owner = node.get_node_manager()
    # walk up through managers that are themselves Nodes
    while isinstance(owner, Node):
        parts.append(owner.get_manager_name())
        owner = owner.get_node_manager()
    # the outermost manager is the SymbolTree
    parts.append(owner.get_manager_name())
    return ".".join(reversed(parts))
|
|
469
|
+
|
|
470
|
+
def local_prim_inits(self) -> List[Node]:
    """
    Return the local primitive nodes constructed inside the forward method, e.g. `abs_inst = P.Abs()`.

    The stored list itself is returned, not a copy.
    """
    prim_nodes = self._local_prim_inits
    return prim_nodes
|
|
473
|
+
|
|
474
|
+
def finish_build(self):
    """Add an Event.TopologicalChangeEvent event once the build of this symbol tree is finished."""
    build_done_event = Event.TopologicalChangeEvent
    self.add_event(build_done_event)
|
|
477
|
+
|
|
478
|
+
def get_ori_cls_name(self) -> str:
|
|
479
|
+
"""
|
|
480
|
+
Get class name of original network.
|
|
481
|
+
|
|
482
|
+
Returns:
|
|
483
|
+
A str represents class name of original network.
|
|
484
|
+
"""
|
|
485
|
+
return self._ori_cls_name
|
|
486
|
+
|
|
487
|
+
def get_opt_cls_name(self) -> str:
|
|
488
|
+
"""
|
|
489
|
+
Get class name of rewritten network.
|
|
490
|
+
|
|
491
|
+
Returns:
|
|
492
|
+
A str represents class name of rewritten network.
|
|
493
|
+
"""
|
|
494
|
+
return self._opt_cls_name
|
|
495
|
+
|
|
496
|
+
def get_module_ast(self):
|
|
497
|
+
"""
|
|
498
|
+
Getter of `_module_ast`.
|
|
499
|
+
|
|
500
|
+
Returns:
|
|
501
|
+
An instance of ast.AST represents ast node of corresponding module.
|
|
502
|
+
"""
|
|
503
|
+
return self._module_ast
|
|
504
|
+
|
|
505
|
+
def set_module_ast(self, ast_node: ast.Module):
|
|
506
|
+
"""
|
|
507
|
+
Setter of _module_ast.
|
|
508
|
+
|
|
509
|
+
Args:
|
|
510
|
+
ast_node (ast.Module): An instance of ast.Module represents ast node of module of corresponding network
|
|
511
|
+
class.
|
|
512
|
+
"""
|
|
513
|
+
self._module_ast = ast_node
|
|
514
|
+
|
|
515
|
+
def get_ast_root(self):
|
|
516
|
+
"""
|
|
517
|
+
Getter of `_root_ast`.
|
|
518
|
+
|
|
519
|
+
Returns:
|
|
520
|
+
An instance of ast.AST represents ast node of corresponding forward method.
|
|
521
|
+
"""
|
|
522
|
+
return self._root_ast
|
|
523
|
+
|
|
524
|
+
def set_ast_root(self, ast_node: ast.FunctionDef):
|
|
525
|
+
"""
|
|
526
|
+
Setter of _root_ast.
|
|
527
|
+
|
|
528
|
+
Args:
|
|
529
|
+
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of
|
|
530
|
+
corresponding network class.
|
|
531
|
+
"""
|
|
532
|
+
self._root_ast = ast_node
|
|
533
|
+
NodeManager.set_manager_ast(self, ast_node)
|
|
534
|
+
|
|
535
|
+
def get_class_ast(self):
|
|
536
|
+
"""
|
|
537
|
+
Getter of `_class_ast`.
|
|
538
|
+
|
|
539
|
+
Returns:
|
|
540
|
+
An instance of ast.ClassDef represents ast node of corresponding network class.
|
|
541
|
+
"""
|
|
542
|
+
return self._class_ast
|
|
543
|
+
|
|
544
|
+
def set_class_ast(self, ast_node: ast.ClassDef):
|
|
545
|
+
"""
|
|
546
|
+
Setter of `_class_ast`.
|
|
547
|
+
|
|
548
|
+
Args:
|
|
549
|
+
ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
|
|
550
|
+
"""
|
|
551
|
+
self._class_ast = ast_node
|
|
552
|
+
|
|
553
|
+
def get_init_func_ast(self):
|
|
554
|
+
"""
|
|
555
|
+
Getter of _init_func_ast.
|
|
556
|
+
|
|
557
|
+
Returns:
|
|
558
|
+
An instance of ast.FunctionDef represents ast node of init method of corresponding network class.
|
|
559
|
+
"""
|
|
560
|
+
return self._init_func_ast
|
|
561
|
+
|
|
562
|
+
def set_init_func_ast(self, ast_node: ast.FunctionDef):
|
|
563
|
+
"""
|
|
564
|
+
Setter of _init_func_ast.
|
|
565
|
+
|
|
566
|
+
Args:
|
|
567
|
+
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of
|
|
568
|
+
corresponding network class.
|
|
569
|
+
"""
|
|
570
|
+
self._init_func_ast = ast_node
|
|
571
|
+
|
|
572
|
+
def get_origin_network(self):
|
|
573
|
+
"""
|
|
574
|
+
Getter of `_origin_network`.
|
|
575
|
+
|
|
576
|
+
Returns:
|
|
577
|
+
An instance of Cell which represents original network.
|
|
578
|
+
"""
|
|
579
|
+
return self._origin_network
|
|
580
|
+
|
|
581
|
+
def get_nodes_dict(self):
|
|
582
|
+
"""Get dict of nodes"""
|
|
583
|
+
return self._nodes
|
|
584
|
+
|
|
585
|
+
def get_node_namer(self):
|
|
586
|
+
"""Get _node_namer"""
|
|
587
|
+
return self._node_namer
|
|
588
|
+
|
|
589
|
+
def is_modified(self):
|
|
590
|
+
"""
|
|
591
|
+
Check whether symbol tree is modified.
|
|
592
|
+
|
|
593
|
+
Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
|
|
594
|
+
the symbol tree is created.
|
|
595
|
+
"""
|
|
596
|
+
return self._modified
|
|
597
|
+
|
|
598
|
+
def set_modified_true(self):
|
|
599
|
+
"""
|
|
600
|
+
Set self._modified true.
|
|
601
|
+
|
|
602
|
+
Self._modified is set true when 'if' exists in the original network.
|
|
603
|
+
In this situation, different original network instance tends to be different.
|
|
604
|
+
Hence, the class name should be updated.
|
|
605
|
+
"""
|
|
606
|
+
self._modified = True
|
|
607
|
+
|
|
608
|
+
def get_import_asts(self):
|
|
609
|
+
"""Get _import_asts"""
|
|
610
|
+
return self._import_asts
|
|
611
|
+
|
|
612
|
+
def get_external_ast(self):
|
|
613
|
+
"""Get _external_ast"""
|
|
614
|
+
return self._external_ast
|
|
615
|
+
|
|
616
|
+
def get_father_class_ast(self):
|
|
617
|
+
"""Get _father_class_ast"""
|
|
618
|
+
return self._father_class_ast
|
|
619
|
+
|
|
620
|
+
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
|
621
|
+
"""
|
|
622
|
+
Getter of inputs in topological relation of current 'node_or_name'.
|
|
623
|
+
|
|
624
|
+
Args:
|
|
625
|
+
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
626
|
+
|
|
627
|
+
Returns:
|
|
628
|
+
A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if
|
|
629
|
+
'node_or_name' not belong to current SymbolTree.
|
|
630
|
+
"""
|
|
631
|
+
|
|
632
|
+
real_node: Optional[Node] = self._get_real_node(node_or_name)
|
|
633
|
+
if real_node is None:
|
|
634
|
+
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
|
|
635
|
+
return []
|
|
636
|
+
return node_or_name.get_inputs()
|
|
637
|
+
|
|
638
|
+
def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]:
|
|
639
|
+
"""
|
|
640
|
+
Getter of outputs in topological relation of current 'node_or_name'.
|
|
641
|
+
|
|
642
|
+
Args:
|
|
643
|
+
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
644
|
+
|
|
645
|
+
Returns:
|
|
646
|
+
A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if
|
|
647
|
+
'node_or_name' not belong to current SymbolTree.
|
|
648
|
+
"""
|
|
649
|
+
|
|
650
|
+
real_node: Optional[Node] = self._get_real_node(node_or_name)
|
|
651
|
+
if real_node is None:
|
|
652
|
+
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
|
|
653
|
+
return []
|
|
654
|
+
if real_node.get_node_type() == NodeType.Output:
|
|
655
|
+
return []
|
|
656
|
+
node_users = []
|
|
657
|
+
for target_users in real_node.get_target_users().values():
|
|
658
|
+
if not target_users:
|
|
659
|
+
continue
|
|
660
|
+
if target_users not in node_users:
|
|
661
|
+
node_users.extend(target_users)
|
|
662
|
+
return node_users
|
|
663
|
+
|
|
664
|
+
def before(self, node_or_name: Union[Node, str]) -> Position:
|
|
665
|
+
"""
|
|
666
|
+
Get insert position before 'node_or_name' in source code list.
|
|
667
|
+
Consider using symbol_tree, node and before/after as position for sub-tree feature.
|
|
668
|
+
|
|
669
|
+
Note:
|
|
670
|
+
Topological order is not determined here which is determined by arguments of node and updated by
|
|
671
|
+
TopologicalManager automatically.
|
|
672
|
+
|
|
673
|
+
Args:
|
|
674
|
+
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
675
|
+
|
|
676
|
+
Returns:
|
|
677
|
+
A Position represents an insert point.
|
|
678
|
+
|
|
679
|
+
Raises:
|
|
680
|
+
AssertError: If 'node_or_name' is not a Node or a str
|
|
681
|
+
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
|
682
|
+
SymbolTree.
|
|
683
|
+
"""
|
|
684
|
+
|
|
685
|
+
node = self._get_real_node(node_or_name)
|
|
686
|
+
if node is None:
|
|
687
|
+
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
688
|
+
return Position.create(node.get_belong_symbol_tree(), node, True)
|
|
689
|
+
|
|
690
|
+
def after(self, node_or_name: Union[Node, str]) -> Position:
|
|
691
|
+
"""
|
|
692
|
+
Get insert position after 'node_or_name' in source code list.
|
|
693
|
+
Consider using symbol_tree, node and before/after as position for sub-tree feature.
|
|
694
|
+
|
|
695
|
+
Note:
|
|
696
|
+
Topological order is not determined here which is determined by arguments of node and updated by
|
|
697
|
+
TopologicalManager automatically.
|
|
698
|
+
|
|
699
|
+
Args:
|
|
700
|
+
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
701
|
+
|
|
702
|
+
Returns:
|
|
703
|
+
A Position represents an insert point.
|
|
704
|
+
|
|
705
|
+
Raises:
|
|
706
|
+
AssertError: If 'node_or_name' is not a Node or a str
|
|
707
|
+
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
|
708
|
+
SymbolTree.
|
|
709
|
+
"""
|
|
710
|
+
node = self._get_real_node(node_or_name)
|
|
711
|
+
if node is None:
|
|
712
|
+
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
713
|
+
return Position.create(node.get_belong_symbol_tree(), node, False)
|
|
714
|
+
|
|
715
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
|
|
716
|
+
insert_to_ast: bool = True):
|
|
717
|
+
"""
|
|
718
|
+
Insert a node before or after base_node.
|
|
719
|
+
|
|
720
|
+
Note:
|
|
721
|
+
Name of node will be unique while inserting node into SymbolTree.
|
|
722
|
+
|
|
723
|
+
ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will
|
|
724
|
+
be saved in global_vars dict while inserting node into SymbolTree.
|
|
725
|
+
|
|
726
|
+
Targets of node will be unique while inserting node into SymbolTree.
|
|
727
|
+
|
|
728
|
+
A field instantiation statement will be added into "init" function of network class using node name as field
|
|
729
|
+
name when `insert_to_ast` is True while inserting node into SymbolTree.
|
|
730
|
+
|
|
731
|
+
An assign statement represents invoking to this node will be added into forward function of network class
|
|
732
|
+
corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into
|
|
733
|
+
SymbolTree.
|
|
734
|
+
|
|
735
|
+
Topological relation is updated and inputs of corresponding node is updated.
|
|
736
|
+
|
|
737
|
+
Args:
|
|
738
|
+
new_node (Node): Node to be inserted.
|
|
739
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
740
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
741
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
742
|
+
NodeManager of symboltree's construct function.
|
|
743
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
744
|
+
|
|
745
|
+
Returns:
|
|
746
|
+
An instance of node which has been inserted into SymbolTree.
|
|
747
|
+
|
|
748
|
+
Raises:
|
|
749
|
+
ValueError: Node in the SymbolTree is inserted into SymbolTree again.
|
|
750
|
+
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
751
|
+
"""
|
|
752
|
+
if new_node.get_belong_symbol_tree():
|
|
753
|
+
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
|
|
754
|
+
|
|
755
|
+
# Check if base_node in current SymbolTree
|
|
756
|
+
if base_node is not None:
|
|
757
|
+
stree = base_node.get_belong_symbol_tree()
|
|
758
|
+
if stree is not None and stree is not self:
|
|
759
|
+
raise ValueError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
|
|
760
|
+
f"current: {self.get_ori_cls_name()}.")
|
|
761
|
+
|
|
762
|
+
# Check if node is inserted between Input node
|
|
763
|
+
if base_node is not None and base_node.get_node_type() == NodeType.Input:
|
|
764
|
+
valid = True
|
|
765
|
+
if before_node:
|
|
766
|
+
valid = False
|
|
767
|
+
if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
|
|
768
|
+
valid = False
|
|
769
|
+
if not valid:
|
|
770
|
+
raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
|
|
771
|
+
|
|
772
|
+
# save target name, which is used to provide unique target
|
|
773
|
+
if new_node.get_targets():
|
|
774
|
+
for target in new_node.get_targets():
|
|
775
|
+
self._target_namer.add_name(str(target))
|
|
776
|
+
|
|
777
|
+
self._handle_custom_obj_in_normalized_args(new_node)
|
|
778
|
+
|
|
779
|
+
# Insert node into NodeManager
|
|
780
|
+
if node_manager is None:
|
|
781
|
+
if base_node is None:
|
|
782
|
+
raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
|
|
783
|
+
node_manager = base_node.get_node_manager()
|
|
784
|
+
|
|
785
|
+
# set node's _belong_symbol_tree
|
|
786
|
+
new_node.set_belong_symbol_tree(self)
|
|
787
|
+
|
|
788
|
+
if node_manager is self:
|
|
789
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
790
|
+
if insert_to_ast:
|
|
791
|
+
# update init-function-ast and construct-function-ast
|
|
792
|
+
self.insert_to_ast_while_insert_node(new_node, base_node, before_node)
|
|
793
|
+
else:
|
|
794
|
+
node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
|
|
795
|
+
|
|
796
|
+
# register code changed event observer, which is used to update _modified flag.
|
|
797
|
+
if new_node.get_node_type() == NodeType.Tree:
|
|
798
|
+
new_node.symbol_tree.reg_observer(self)
|
|
799
|
+
elif isinstance(new_node, NodeManager):
|
|
800
|
+
new_node.reg_observer(self)
|
|
801
|
+
|
|
802
|
+
return new_node
|
|
803
|
+
|
|
804
|
+
def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
|
|
805
|
+
"""
|
|
806
|
+
Append a node to SymbolTree.
|
|
807
|
+
|
|
808
|
+
Args:
|
|
809
|
+
node (Node): An instance of node to be appended.
|
|
810
|
+
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
811
|
+
True.
|
|
812
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
813
|
+
NodeManager of symboltree's construct function.
|
|
814
|
+
|
|
815
|
+
Returns:
|
|
816
|
+
An instance of node which has been appended to SymbolTree.
|
|
817
|
+
"""
|
|
818
|
+
if node_manager is None:
|
|
819
|
+
node_manager = self
|
|
820
|
+
return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
|
|
821
|
+
|
|
822
|
+
def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
|
|
823
|
+
"""
|
|
824
|
+
Append an original field node to SymbolTree. An original field node represents a node created from existing
|
|
825
|
+
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
|
|
826
|
+
while these nodes appending to SymbolTree.
|
|
827
|
+
This method is called while building SymbolTree usually.
|
|
828
|
+
|
|
829
|
+
Args:
|
|
830
|
+
node (Node): An instance of node to be appended.
|
|
831
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
832
|
+
NodeManager of symboltree's construct function.
|
|
833
|
+
|
|
834
|
+
Returns:
|
|
835
|
+
An instance of node which has been appended to SymbolTree.
|
|
836
|
+
"""
|
|
837
|
+
return self.append_node(node, node_manager, False)
|
|
838
|
+
|
|
839
|
+
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
|
|
840
|
+
node_manager: NodeManager = None):
|
|
841
|
+
"""
|
|
842
|
+
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
|
|
843
|
+
This method is called while building SymbolTree usually.
|
|
844
|
+
|
|
845
|
+
Args:
|
|
846
|
+
ast_node (ast.AST): A ast Node corresponding to current parameter.
|
|
847
|
+
param_name (str): A str represents name of parameter of forward method of network class.
|
|
848
|
+
default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
|
|
849
|
+
means parameter has no default value.
|
|
850
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
851
|
+
NodeManager of symboltree's construct function.
|
|
852
|
+
|
|
853
|
+
Returns:
|
|
854
|
+
An instance of input node which has been appended to SymbolTree.
|
|
855
|
+
"""
|
|
856
|
+
if param_name == "self":
|
|
857
|
+
return
|
|
858
|
+
# check param_name duplicated
|
|
859
|
+
if node_manager is None:
|
|
860
|
+
node_manager = self
|
|
861
|
+
for input_node in node_manager.get_input_nodes():
|
|
862
|
+
targets = input_node.get_targets()
|
|
863
|
+
if len(targets) != 1:
|
|
864
|
+
raise RuntimeError("targets should have 1 elements")
|
|
865
|
+
target: ScopedValue = targets[0]
|
|
866
|
+
if target.type != ValueType.NamingValue:
|
|
867
|
+
raise RuntimeError("target.type should equal to ValueType.NamingValue")
|
|
868
|
+
if target.scope != "":
|
|
869
|
+
raise RuntimeError("target.scope should be empty")
|
|
870
|
+
exist_param = target.value
|
|
871
|
+
if exist_param == param_name:
|
|
872
|
+
raise RuntimeError("input duplicated:", param_name)
|
|
873
|
+
input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
|
|
874
|
+
self.append_origin_field(input_node, node_manager)
|
|
875
|
+
|
|
876
|
+
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
|
|
877
|
+
node_manager: NodeManager = None) -> Optional[Node]:
|
|
878
|
+
"""
|
|
879
|
+
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
|
|
880
|
+
a list or a dict.
|
|
881
|
+
This method is called while building SymbolTree usually.
|
|
882
|
+
|
|
883
|
+
Args:
|
|
884
|
+
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
885
|
+
ast_node (ast.AST): A ast node represents ast node.
|
|
886
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
887
|
+
NodeManager of symboltree's construct function.
|
|
888
|
+
|
|
889
|
+
Returns:
|
|
890
|
+
An instance of python node if a new node has been appended to SymbolTree else None.
|
|
891
|
+
"""
|
|
892
|
+
if ast_node is None:
|
|
893
|
+
return None
|
|
894
|
+
if isinstance(ast_node, (list, dict)) and not ast_node:
|
|
895
|
+
return None
|
|
896
|
+
return self.append_python_node(ast_scope, ast_node, node_manager)
|
|
897
|
+
|
|
898
|
+
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
|
|
899
|
+
"""
|
|
900
|
+
Append a python node to SymbolTree.
|
|
901
|
+
This method is called while building SymbolTree usually.
|
|
902
|
+
|
|
903
|
+
Args:
|
|
904
|
+
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
905
|
+
ast_node (ast.AST): A ast node represents ast node.
|
|
906
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
907
|
+
NodeManager of symboltree's construct function.
|
|
908
|
+
|
|
909
|
+
Returns:
|
|
910
|
+
An instance of python node which has been appended to SymbolTree.
|
|
911
|
+
"""
|
|
912
|
+
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
913
|
+
node_name = type(ast_node).__name__
|
|
914
|
+
node = Node.create_python_node(ast_node, node_name)
|
|
915
|
+
if node_manager is None or node_manager is self:
|
|
916
|
+
NodeManager.append_python_node(self, node)
|
|
917
|
+
else:
|
|
918
|
+
node_manager.append_python_node(node)
|
|
919
|
+
return node
|
|
920
|
+
|
|
921
|
+
def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
|
|
922
|
+
node_manager: NodeManager = None) -> Node:
|
|
923
|
+
"""
|
|
924
|
+
Update return value of return of forward method of network class.
|
|
925
|
+
|
|
926
|
+
Args:
|
|
927
|
+
return_value (str): A str represents new return value.
|
|
928
|
+
arg_index (int): A int indicates which value in return to be updated.
|
|
929
|
+
return_idx (int): A int indicates which return node to be updated. Default: 0.
|
|
930
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
|
|
931
|
+
symboltree's construct function.
|
|
932
|
+
|
|
933
|
+
Returns:
|
|
934
|
+
An instance of node represents return node after updated.
|
|
935
|
+
"""
|
|
936
|
+
node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
|
|
937
|
+
if not node_returns:
|
|
938
|
+
raise RuntimeError("Current node_manager has no output")
|
|
939
|
+
if return_idx >= len(node_returns):
|
|
940
|
+
raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
|
|
941
|
+
node_return = node_returns[return_idx]
|
|
942
|
+
self.set_node_arg(node_return, arg_index, return_value)
|
|
943
|
+
return node_return
|
|
944
|
+
|
|
945
|
+
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
|
|
946
|
+
"""
|
|
947
|
+
Erase a node from SymbolTree.
|
|
948
|
+
|
|
949
|
+
Topological relation will be updated.
|
|
950
|
+
|
|
951
|
+
Args:
|
|
952
|
+
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
953
|
+
|
|
954
|
+
Returns:
|
|
955
|
+
An instance of node which has been erased from SymbolTree.
|
|
956
|
+
|
|
957
|
+
Raises:
|
|
958
|
+
RuntimeError: If 'node_or_name' is not in current SymbolTree.
|
|
959
|
+
RuntimeError: If erase corresponding ast node failed.
|
|
960
|
+
"""
|
|
961
|
+
|
|
962
|
+
node = self._get_real_node(node_or_name)
|
|
963
|
+
if node is None:
|
|
964
|
+
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
965
|
+
# erase node in NodeManager
|
|
966
|
+
node_manager = node.get_node_manager()
|
|
967
|
+
|
|
968
|
+
logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
|
|
969
|
+
f"node_manager: {node_manager.get_manager_name()}, "
|
|
970
|
+
f"code: {astunparse.unparse(node.get_ast()).strip()}, "
|
|
971
|
+
f"node_name:{node.get_name()}")
|
|
972
|
+
|
|
973
|
+
if node_manager is self:
|
|
974
|
+
NodeManager.erase_node(self, node)
|
|
975
|
+
if isinstance(node, ControlFlow):
|
|
976
|
+
ret = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node.get_ast(), node.is_orelse)
|
|
977
|
+
else:
|
|
978
|
+
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
979
|
+
if not ret:
|
|
980
|
+
raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
981
|
+
else:
|
|
982
|
+
node_manager.erase_node(node)
|
|
983
|
+
node.set_belong_symbol_tree(None)
|
|
984
|
+
self._deleted_node.append(node.get_name())
|
|
985
|
+
return node
|
|
986
|
+
|
|
987
|
+
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
988
|
+
"""
|
|
989
|
+
Replace an old_node with a node list.
|
|
990
|
+
|
|
991
|
+
Args:
|
|
992
|
+
old_node (Node): Node to be replaced.
|
|
993
|
+
new_nodes (list[Node]): Node list to replace in.
|
|
994
|
+
|
|
995
|
+
Returns:
|
|
996
|
+
Last node in new_nodes list.
|
|
997
|
+
|
|
998
|
+
Raises:
|
|
999
|
+
RuntimeError: If 'old_node' is isolated.
|
|
1000
|
+
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
1001
|
+
"""
|
|
1002
|
+
real_old_node = self._get_real_node(old_node)
|
|
1003
|
+
if real_old_node is None:
|
|
1004
|
+
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
1005
|
+
# insert new_nodes into node_manager
|
|
1006
|
+
node_manager = real_old_node.get_node_manager()
|
|
1007
|
+
# insert new_nodes into NodeManager
|
|
1008
|
+
base_node = old_node
|
|
1009
|
+
for node in new_nodes:
|
|
1010
|
+
self.insert_node(node, base_node, False, node_manager, True)
|
|
1011
|
+
base_node = node
|
|
1012
|
+
self.erase_node(old_node)
|
|
1013
|
+
return new_nodes[-1]
|
|
1014
|
+
|
|
1015
|
+
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
|
|
1016
|
+
"""
|
|
1017
|
+
Set argument of 'node'.
|
|
1018
|
+
|
|
1019
|
+
Args:
|
|
1020
|
+
node (Union[Node, str]): Node to be modified. Can be a node or name of node.
|
|
1021
|
+
index (int): Indicate which input being modified.
|
|
1022
|
+
arg (Union[ScopedValue, str]): New argument to been set.
|
|
1023
|
+
|
|
1024
|
+
Raises:
|
|
1025
|
+
RuntimeError: If 'node' is not belong to current SymbolTree.
|
|
1026
|
+
"""
|
|
1027
|
+
|
|
1028
|
+
real_node = self._get_real_node(node)
|
|
1029
|
+
if real_node is None:
|
|
1030
|
+
raise RuntimeError("Node is not belong to current SymbolTree: ", node)
|
|
1031
|
+
|
|
1032
|
+
new_arg, old_arg = node.set_arg(arg, index)
|
|
1033
|
+
node.get_node_manager().on_update_arg(node, index, old_arg, new_arg)
|
|
1034
|
+
|
|
1035
|
+
def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
|
|
1036
|
+
out_idx: Optional[int] = None):
|
|
1037
|
+
"""
|
|
1038
|
+
Set argument of 'dst_node' by another Node.
|
|
1039
|
+
|
|
1040
|
+
Args:
|
|
1041
|
+
dst_node (Node): Node to be modified. Can be a node or name of node.
|
|
1042
|
+
arg_idx (int): Indicate which input being modified.
|
|
1043
|
+
src_node (Node): Node as new input. Can be a node or name of node.
|
|
1044
|
+
out_idx ([int, optional]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None
|
|
1045
|
+
which means use first output of 'node_to_link' as new input.
|
|
1046
|
+
|
|
1047
|
+
Raises:
|
|
1048
|
+
RuntimeError: If 'dst_node' is not belong to current SymbolTree.
|
|
1049
|
+
RuntimeError: If 'src_node' is not belong to current SymbolTree.
|
|
1050
|
+
RuntimeError: If 'out_idx' is out of range.
|
|
1051
|
+
RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered.
|
|
1052
|
+
"""
|
|
1053
|
+
|
|
1054
|
+
real_dst_node = self._get_real_node(dst_node)
|
|
1055
|
+
if real_dst_node is None:
|
|
1056
|
+
raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node)
|
|
1057
|
+
real_src_node = self._get_real_node(src_node)
|
|
1058
|
+
if real_src_node is None:
|
|
1059
|
+
raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node)
|
|
1060
|
+
|
|
1061
|
+
targets = real_src_node.get_targets()
|
|
1062
|
+
if out_idx is None:
|
|
1063
|
+
if len(targets) != 1:
|
|
1064
|
+
raise RuntimeError("node should has one output when out_idx is not provided")
|
|
1065
|
+
out_idx = 0
|
|
1066
|
+
if out_idx >= len(targets):
|
|
1067
|
+
raise RuntimeError("out_idx out of range: ", out_idx)
|
|
1068
|
+
new_arg = targets[out_idx]
|
|
1069
|
+
real_dst_node.set_arg(new_arg, arg_idx)
|
|
1070
|
+
real_dst_node.get_node_manager().on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
|
|
1071
|
+
|
|
1072
|
+
def unique_name(self, name: str):
|
|
1073
|
+
"""Get a unique name in the symboltree"""
|
|
1074
|
+
return self._target_namer.get_name(name)
|
|
1075
|
+
|
|
1076
|
+
def unique_func_name(self, name: str):
|
|
1077
|
+
"""Get a unique function name in the symboltree"""
|
|
1078
|
+
if not hasattr(self._origin_network, name):
|
|
1079
|
+
return name
|
|
1080
|
+
suffix = 1
|
|
1081
|
+
while hasattr(self._origin_network, f"{name}_{suffix}"):
|
|
1082
|
+
suffix += 1
|
|
1083
|
+
return f"{name}_{suffix}"
|
|
1084
|
+
|
|
1085
|
+
def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
|
|
1086
|
+
"""
|
|
1087
|
+
Set target of `node` .
|
|
1088
|
+
|
|
1089
|
+
Args:
|
|
1090
|
+
node (Union[Node, str]): Node to be modified. Can be a node or name of node.
|
|
1091
|
+
index (int): Indicate which target being modified.
|
|
1092
|
+
arg (Union[ScopedValue, str]): New target to been set.
|
|
1093
|
+
|
|
1094
|
+
Raises:
|
|
1095
|
+
ValueError: If `node` is not belong to current SymbolTree.
|
|
1096
|
+
ValueError: If index of `node` 's target is greater than number of targets.
|
|
1097
|
+
"""
|
|
1098
|
+
|
|
1099
|
+
real_node = self._get_real_node(node)
|
|
1100
|
+
if real_node is None:
|
|
1101
|
+
raise ValueError("Node is not belong to current SymbolTree: ", node)
|
|
1102
|
+
if isinstance(target, str):
|
|
1103
|
+
target = ScopedValue.create_naming_value(target)
|
|
1104
|
+
targets = node.get_targets()
|
|
1105
|
+
if index >= len(targets):
|
|
1106
|
+
raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}")
|
|
1107
|
+
old_target = targets[index]
|
|
1108
|
+
targets[index] = target
|
|
1109
|
+
node.set_targets(targets)
|
|
1110
|
+
self._topo_mgr.on_update_target(node, index, old_target, target)
|
|
1111
|
+
|
|
1112
|
+
def all_nodes(self, subtree_nodes: bool = True):
|
|
1113
|
+
"""
|
|
1114
|
+
Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
|
|
1115
|
+
|
|
1116
|
+
Args:
|
|
1117
|
+
subtree_nodes (bool): Whether include nodes in subtree. Default: True.
|
|
1118
|
+
|
|
1119
|
+
Returns:
|
|
1120
|
+
A list of nodes.
|
|
1121
|
+
"""
|
|
1122
|
+
nodes = []
|
|
1123
|
+
node_managers = [self]
|
|
1124
|
+
while node_managers:
|
|
1125
|
+
node_manager = node_managers.pop()
|
|
1126
|
+
nodes.extend(node_manager.nodes())
|
|
1127
|
+
for node in node_manager.nodes():
|
|
1128
|
+
if isinstance(node, NodeManager):
|
|
1129
|
+
node_managers.append(node)
|
|
1130
|
+
if subtree_nodes:
|
|
1131
|
+
for tree_node in self.get_tree_nodes():
|
|
1132
|
+
stree = tree_node.symbol_tree
|
|
1133
|
+
nodes.extend(stree.all_nodes())
|
|
1134
|
+
return nodes
|
|
1135
|
+
|
|
1136
|
+
def get_node_from_name(self, node_name: str):
|
|
1137
|
+
"""
|
|
1138
|
+
Get node from all NodeManagers in current symbol tree by `node_name`.
|
|
1139
|
+
|
|
1140
|
+
Args:
|
|
1141
|
+
node_name (str): A str represents name of node as key of query.
|
|
1142
|
+
|
|
1143
|
+
Returns:
|
|
1144
|
+
An instance of Node if found else None.
|
|
1145
|
+
"""
|
|
1146
|
+
node_managers = [self]
|
|
1147
|
+
while node_managers:
|
|
1148
|
+
node_manager = node_managers.pop()
|
|
1149
|
+
node = node_manager.get_node(node_name)
|
|
1150
|
+
if node:
|
|
1151
|
+
return node
|
|
1152
|
+
for node in node_manager.nodes():
|
|
1153
|
+
if isinstance(node, NodeManager):
|
|
1154
|
+
node_managers.append(node)
|
|
1155
|
+
return None
|
|
1156
|
+
|
|
1157
|
+
def get_node_tabulate(self, all_nodes: bool = False) -> str:
|
|
1158
|
+
"""
|
|
1159
|
+
Get nodes information and nodes' topological relations.
|
|
1160
|
+
|
|
1161
|
+
Args:
|
|
1162
|
+
all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
|
|
1163
|
+
nodes, CellContainer nodes and sub symbol trees.
|
|
1164
|
+
|
|
1165
|
+
Returns:
|
|
1166
|
+
String of nodes' information and topological relations.
|
|
1167
|
+
"""
|
|
1168
|
+
try:
|
|
1169
|
+
from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
|
|
1170
|
+
except ImportError:
|
|
1171
|
+
logger.warning("print_node_tabulate relies on the library `tabulate`, "
|
|
1172
|
+
"which could not be found on this machine. Run `pip "
|
|
1173
|
+
"install tabulate` to install the library.")
|
|
1174
|
+
return ""
|
|
1175
|
+
dump_str = NodeManager.dump(self, self.get_manager_name())
|
|
1176
|
+
if all_nodes:
|
|
1177
|
+
node_managers = [self]
|
|
1178
|
+
while node_managers:
|
|
1179
|
+
node_manager = node_managers.pop()
|
|
1180
|
+
for node in node_manager.nodes():
|
|
1181
|
+
if isinstance(node, NodeManager):
|
|
1182
|
+
dump_str += node.dump(SymbolTree.get_node_full_name(node))
|
|
1183
|
+
node_managers.append(node)
|
|
1184
|
+
for tree_node in self.get_tree_nodes():
|
|
1185
|
+
stree = tree_node.symbol_tree
|
|
1186
|
+
dump_str += stree.get_node_tabulate(all_nodes)
|
|
1187
|
+
return dump_str
|
|
1188
|
+
|
|
1189
|
+
def dump(self):
    """Dump the graph of this symbol tree by delegating to a SymbolTreeDumper."""
    dumper = SymbolTreeDumper(self)
    return_value = dumper.dump()  # noqa: F841  # result is intentionally discarded
    del return_value
|
|
1193
|
+
|
|
1194
|
+
def check_body_exist(self, body, code_bodies):
    """
    Check whether body already exist in code_bodies.

    - ast.Import / ast.ImportFrom / ast.Expr: compared by unparsed source against the strings cached
      in ``self._tmp_import_strs``; an unseen string is cached and ``False`` is returned.
    - ast.ClassDef / ast.FunctionDef: looked up by name in `code_bodies` with AstClassFinder /
      AstFunctionFinder.
    - Any other ast: reported as not existing.

    Returns:
        bool, ``True`` if an equivalent body already exists.
    """
    if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
        source_str = astunparse.unparse(body)
        if source_str not in self._tmp_import_strs:
            self._tmp_import_strs.append(source_str)
            return False
        return True

    if isinstance(body, ast.ClassDef):
        finder_type = AstClassFinder
    elif isinstance(body, ast.FunctionDef):
        finder_type = AstFunctionFinder
    else:
        return False

    # ast.Module requires `type_ignores` on newer Python versions.
    if sys.version_info >= (3, 9):
        wrapper = ast.Module(body=code_bodies, type_ignores=[])
    else:
        wrapper = ast.Module(body=code_bodies)
    matches = finder_type(wrapper).find_all(body.name)
    return bool(matches)
|
|
1223
|
+
|
|
1224
|
+
def deduplicate_unmodified_stree(self, code_bodies):
    """
    Merge unmodified symbol trees whose init functions became identical.

    Init functions of unmodified strees of the same network type may differ while their subnets
    still carry different class names. Once `code_bodies` is fully generated and those names are
    updated, every stree in a group of ``self._tmp_unmodified_strees`` whose unparsed init function
    equals that of an earlier member has its class removed from `code_bodies`, and all references to
    its class name are renamed to the earlier member's class name. The renaming replacer is stored in
    ``self._tmp_replacers``; the pass repeats while anything was merged, because renaming may make
    more init functions identical.
    """
    def wrap(bodies):
        # ast.Module requires `type_ignores` on newer Python versions.
        if sys.version_info >= (3, 9):
            return ast.Module(body=bodies, type_ignores=[])
        return ast.Module(body=bodies)

    class_finder = AstClassFinder(wrap(code_bodies))
    name_replacer = AstReplacer(wrap(code_bodies))
    merged_any = False
    for group in self._tmp_unmodified_strees.values():
        if len(group) <= 1:
            continue
        # Unparse all init functions first: replace_all below rewrites names in place.
        init_codes = [astunparse.unparse(member.get_init_func_ast()) for member in group]
        first_index_of = {}
        for position, code in enumerate(init_codes):
            first_index_of.setdefault(code, position)
        duplicate_positions = []
        for position, code in enumerate(init_codes):
            keep_position = first_index_of[code]
            if keep_position == position:
                continue
            kept_name = group[keep_position].get_opt_cls_name()
            dup_name = group[position].get_opt_cls_name()
            logger.debug(f"replace stree:{dup_name} to {kept_name}.")
            for class_ast in class_finder.find_all(dup_name):
                code_bodies.remove(class_ast)
            name_replacer.replace_all(dup_name, kept_name)
            duplicate_positions.append(position)
            merged_any = True
        if duplicate_positions:
            # Mutate in place: `group` is the list stored in self._tmp_unmodified_strees.
            group[:] = [member for pos, member in enumerate(group) if pos not in duplicate_positions]

    if merged_any:
        self._tmp_replacers.append(name_replacer)
        self.deduplicate_unmodified_stree(code_bodies)
|
|
1269
|
+
|
|
1270
|
+
def update_unmodified_stree(self, stree, code_bodies) -> bool:
    """
    For the unmodified symbol tree, only one definition code remains in the generated code.

    Modified strees always keep their own class. The first unmodified stree of a network type is
    recorded in ``self._tmp_unmodified_strees``. A later one with a different unparsed init function
    is recorded too, for deduplicate_unmodified_stree to retry once names are final. A later one
    with an identical init function has its class name replaced by the first stree's class name
    throughout `code_bodies`; the replacer is kept in ``self._tmp_replacers``.

    Returns:
        bool, ``True`` if `stree` was mapped onto an existing class and must not be emitted.
    """
    if stree.is_modified():
        logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.")
        return False

    network_type = type(stree.get_origin_network())
    same_type_strees = self._tmp_unmodified_strees.get(network_type)
    if not same_type_strees:
        self._tmp_unmodified_strees[network_type] = [stree]
        logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.")
        return False

    first_stree = same_type_strees[0]
    first_stree_cls_name = first_stree.get_opt_cls_name()
    own_init = astunparse.unparse(stree.get_init_func_ast())
    first_init = astunparse.unparse(first_stree.get_init_func_ast())
    if own_init != first_init:
        # The init ast may still change after subtrees are inserted; deduplicate later.
        same_type_strees.append(stree)
        logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.")
        return False

    # An identical class is already emitted: point every usage at it instead.
    if sys.version_info >= (3, 9):
        wrapper = ast.Module(body=code_bodies, type_ignores=[])
    else:
        wrapper = ast.Module(body=code_bodies)
    replacer = AstReplacer(wrapper)
    logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.")
    replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name)
    self._tmp_replacers.append(replacer)
    return True
|
|
1305
|
+
|
|
1306
|
+
def init_code_bodies(self, code_bodies: list) -> int:
    """
    Init code bodies with the fixed header of the generated code.

    Appends the basic imports (``sys``, ``mindspore``, ``nn``, ``Cell``, ``functional as F``), a
    ``#`` divider, the user custom codes from get_custom_codes() and a second divider.

    Returns:
        int, the length of `code_bodies` afterwards.
    """
    def divider():
        return ast.Expr(ast.Name("#", ast.Load()))

    code_bodies.extend([
        ast.Import([ast.alias(name='sys', asname=None)]),
        ast.Import([ast.alias(name='mindspore', asname=None)]),
        ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.ops', names=[ast.alias(name='functional', asname='F')], level=0),
        divider(),
    ])
    code_bodies.extend(self.get_custom_codes())
    code_bodies.append(divider())
    return len(code_bodies)
|
|
1322
|
+
|
|
1323
|
+
def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int:
    """
    Convert nodes in stree to code_bodies.

    Asts are inserted into `code_bodies` starting at index `dividing_pos`, in this order:

    - external function asts with their imports, iterated in reverse order of
      ``stree.get_external_ast()`` and skipping functions already present;
    - father class asts with their imports, skipping classes already present;
    - import asts of the symbol tree;
    - the body of the symbol tree's module ast (if it has one);
    - a ``#`` divider;
    - then, recursively, each subtree that update_unmodified_stree does not map onto an existing
      class. Subtrees are inserted at the dividing position, i.e. after the external functions and
      father classes but before this tree's own imports and classes.

    Args:
        stree (SymbolTree): The symbol tree to convert.
        code_bodies (list): Top-level asts of the generated module; modified in place.
        dividing_pos (int): Index at which this tree's asts start being inserted. Default: ``0``.

    Returns:
        int, the new dividing position, where the next subtree's asts are to be inserted.
    """
    insert_pos = dividing_pos
    # Add external asts into code_bodies
    for ast_func, import_asts in reversed(stree.get_external_ast().items()):
        if self.check_body_exist(ast_func, code_bodies):
            continue
        # add imports of external_ast; the import-string cache is reset so that only
        # duplicates within this function's own import list are skipped
        self._tmp_import_strs.clear()
        for ast_import in import_asts:
            if not self.check_body_exist(ast_import, code_bodies):
                code_bodies.insert(insert_pos, ast_import)
                insert_pos += 1
        # add external_ast
        code_bodies.insert(insert_pos, ast_func)
        insert_pos += 1
        # add divide
        code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
        insert_pos += 1

    # Add father class asts into code_bodies
    for ast_class, import_asts in stree.get_father_class_ast().items():
        if self.check_body_exist(ast_class, code_bodies):
            continue
        # add imports of father class
        self._tmp_import_strs.clear()
        for ast_import in import_asts:
            if not self.check_body_exist(ast_import, code_bodies):
                code_bodies.insert(insert_pos, ast_import)
                insert_pos += 1
        # add ast of father class
        code_bodies.insert(insert_pos, ast_class)
        insert_pos += 1
        # add divide
        code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
        insert_pos += 1

    # external functions and father class are above the dividing_pos to support deduplication.
    dividing_pos = insert_pos

    # Add import asts of symbol tree into code_bodies
    self._tmp_import_strs.clear()
    for body in stree.get_import_asts():
        if not self.check_body_exist(body, code_bodies):
            code_bodies.insert(insert_pos, body)
            insert_pos += 1

    # Add class asts of symbol tree into code_bodies
    if stree.get_module_ast():
        for body in stree.get_module_ast().body:
            if self.check_body_exist(body, code_bodies):
                continue
            code_bodies.insert(insert_pos, body)
            insert_pos += 1

    # add divide
    code_bodies.insert(insert_pos, ast.Expr(ast.Name("#", ast.Load())))
    insert_pos += 1

    # Add subtrees to code_bodies; each one is inserted at the running dividing position and
    # returns the position for the next one.
    for node in stree.get_tree_nodes():
        sub_stree = node.symbol_tree
        # For the unmodified class, update class name to name of first class
        if self.update_unmodified_stree(sub_stree, code_bodies):
            continue
        dividing_pos = self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, dividing_pos)

    # return new dividing position
    return dividing_pos
|
|
1400
|
+
|
|
1401
|
+
def get_code(self) -> str:
    """
    Get source code of modified network.

    While generating, class names and imported names are temporarily rewritten by AstReplacers
    collected in ``self._tmp_replacers`` (deduplication of unmodified symbol trees, suffixing of
    clashing imports). All of them are undone before returning, also when generation raises, so
    the symbol trees' asts are left unchanged.

    Returns:
        A str represents source code of modified network.
    """
    self._tmp_import_strs.clear()
    self._tmp_unmodified_strees.clear()
    self._tmp_replacers.clear()
    try:
        code_bodies = []
        begin_pos = self.init_code_bodies(code_bodies)
        self.convert_stree_to_code_bodies(self, code_bodies, begin_pos)
        self.deduplicate_unmodified_stree(code_bodies)
        if sys.version_info >= (3, 9):
            gencode_module = ast.Module(body=code_bodies, type_ignores=[])
        else:
            gencode_module = ast.Module(body=code_bodies)
        SymbolTree._remove_unused_import(gencode_module)
        self._process_duplicate_name_modules(gencode_module)
        SymbolTree._remove_duplicated_import(gencode_module)
        SymbolTree._remove_arg_annotations(gencode_module)
        ast.fix_missing_locations(self._module_ast)
        return astunparse.unparse(gencode_module)
    finally:
        # Revert the names to their original state. Undo in reverse (LIFO) order: a later
        # replacer may rename a name produced by an earlier one (e.g. the recursive passes of
        # deduplicate_unmodified_stree), so undoing forwards would not restore the original.
        # NOTE(review): assumes undo_all restores the old value recorded in each trace entry.
        for replacer in reversed(self._tmp_replacers):
            replacer.undo_all()
|
|
1429
|
+
|
|
1430
|
+
def get_network(self):
    """
    Get modified network.

    The class generated from this symbol tree is loaded, instantiated with the original network,
    completed with the members only the original network has, and its parameter names refreshed.

    Returns:
        A network object.
    """
    network_cls = self._get_cls_through_file()
    rewritten = network_cls(self._origin_network)
    self._merge_origin_property(rewritten)
    # Parameter names may be duplicated after inserting cells into CellList/SequentialCell,
    # so regenerate them.
    rewritten.update_parameters_name()
    return rewritten
|
|
1444
|
+
|
|
1445
|
+
def set_saved_file_name(self, file_name: str):
    """Set the file save_network_to_file writes to, appending ``.py`` when it is missing."""
    has_suffix = file_name.endswith(".py")
    self._saved_file_name = file_name if has_suffix else file_name + ".py"
|
|
1450
|
+
|
|
1451
|
+
def get_saved_file_name(self):
    """Return the file name set by set_saved_file_name."""
    saved = self._saved_file_name
    return saved
|
|
1453
|
+
|
|
1454
|
+
def save_network_to_file(self):
    """
    Write the source code of the rewritten network to ``self._saved_file_name``.

    The code is generated before the target file is touched, so a failure in code generation
    leaves an existing file intact. An existing file is removed first so that the new file is
    created with mode ``stat.S_IRWXU``.
    """
    source = self.get_code()
    abs_path = os.path.realpath(self._saved_file_name)
    if os.path.isfile(abs_path):
        os.remove(abs_path)
    with os.fdopen(os.open(abs_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRWXU), 'wb') as f:
        f.write(source.encode('utf-8'))
        f.flush()
|
|
1462
|
+
|
|
1463
|
+
|
|
1464
|
+
def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False):
    """
    Flatten nodes in ControlFlow node.

    Every node inside the branch `node` is moved into the node manager containing `node`, both in
    the node graph and in the ast bodies. The first node is anchored on the else-branch node if
    there is one, otherwise on the body-branch node; each further node is anchored on the
    previously moved one. Afterwards `node` itself is erased.

    Args:
        node (ControlFlow): The branch whose nodes are flattened.
        erase_another_branch (bool): Also erase the sibling branch of `node`. Default: ``False``.
        erase_nodes_after_return (bool): Erase every node that follows the first Output (return)
            node in the enclosing node manager. Default: ``False``.

    Raises:
        ValueError: If `node` is not a ControlFlow, or its node manager is not a SymbolTree,
            CallFunction or ControlFlow.
    """
    if not isinstance(node, ControlFlow):
        raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.")
    upper_node_manager = node.get_node_manager()
    if isinstance(upper_node_manager, (SymbolTree, CallFunction)):
        ast_bodies = upper_node_manager.get_manager_ast().body
    elif isinstance(upper_node_manager, ControlFlow):
        ast_bodies = upper_node_manager.get_manager_ast()
    else:
        raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, "
                         f"ControlFlow], but the node is in {type(upper_node_manager)}.")
    base_node = node.orelse_node if node.orelse_node else node.body_node
    # Iterate over a copy: nodes are erased from `node` inside the loop.
    for n in list(node.nodes()):
        self.erase_node(n)
        self.insert_node(n, base_node, False, upper_node_manager, False)
        AstModifier.insert_ast_to_bodies(ast_bodies, n.get_ast(), base_node.get_ast(), False)
        base_node = n
    self.erase_node(node)
    # remove another branch
    if erase_another_branch:
        if node.is_orelse:
            self.erase_node(node.body_node)
        elif node.orelse_node is not None:
            self.erase_node(node.orelse_node)
    # remove nodes after return node
    if erase_nodes_after_return:
        has_return = False
        # Iterate over a copy as well: erase_node removes nodes from upper_node_manager.
        for n in list(upper_node_manager.nodes()):
            if has_return:
                logger.warning(f"Node {n.get_name()} which is behind the flatten return node is "
                               f"automatically erased.")
                self.erase_node(n)
            elif n.get_node_type() == NodeType.Output:
                has_return = True
|
|
1499
|
+
|
|
1500
|
+
def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool):
    """
    Eval ast_node and get result, only used in control flow node.

    Args:
        ast_node (ast.AST): The expression to evaluate.

    Returns:
        (bool, bool): ``(success, result)``. `success` is ``False`` when the expression cannot be
        evaluated (the module of the original network is not found, or evaluation raises); `result`
        is then ``False`` as well. Otherwise `result` is the truth value of the expression.
    """
    # ast.Constant can be checked without eval.
    if isinstance(ast_node, ast.Constant):
        return True, bool(ast_node.value)
    # Get the module where the code of ast_node is located, so that names in the expression
    # resolve against that module's globals.
    file_path = inspect.getfile(type(self.get_origin_network()))
    module = None
    for m in list(sys.modules.values()):
        if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == os.path.normcase(file_path):
            module = m
            break
    if module is None:
        logger.warning("Failed to get module of ast_node.")
        return False, False
    # eval ast_node and get result
    logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}")
    ast_expr = ast.fix_missing_locations(ast.Expression(ast_node))
    try:
        # NOTE(review): eval runs the expression with the network module's globals; this is only
        # acceptable because ast_node is assumed to come from the network's own source code.
        # pylint: disable=eval-used
        result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, locals())
    except Exception as e:  # pylint: disable=broad-except
        logger.debug(f"Cannot get result of ast_node by eval, err:{e}")
        return False, False
    logger.debug(f"Eval ast result success, result: {result}")
    return True, bool(result)
|
|
1530
|
+
|
|
1531
|
+
def flatten_static_if_control_flow(self):
    """
    For static if control flow, flatten codes in branch which will be executed and erase another branch.

    Control flow nodes whose ``test_result`` is known are resolved. If true, the body branch is
    flattened. If false, the else branch is flattened, or the body branch is erased when there is
    no else branch. Flattening also erases the other branch and any nodes after a return in the
    enclosing scope.
    """
    for candidate in list(self.all_nodes()):
        owner = candidate.get_belong_symbol_tree()
        if not owner:
            # Already erased by an earlier flatten in this loop.
            continue
        if not isinstance(candidate, ControlFlow) or candidate.test_result is None:
            continue
        if candidate.test_result:
            owner.flatten_nodes(candidate.body_node, True, True)
        elif candidate.orelse_node is not None:
            owner.flatten_nodes(candidate.orelse_node, True, True)
        else:
            owner.erase_node(candidate.body_node)
|
|
1548
|
+
|
|
1549
|
+
def add_custom_codes(self, code: str):
    """Parse `code` and append its top-level statements to the user custom codes."""
    parsed_module = ast.parse(code)
    self._custom_codes += parsed_module.body
|
|
1553
|
+
|
|
1554
|
+
def get_custom_codes(self) -> List[ast.AST]:
    """Return the list of user custom code asts collected by add_custom_codes (not a copy)."""
    custom = self._custom_codes
    return custom
|
|
1557
|
+
|
|
1558
|
+
def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None):
    """
    Save file path into stree._import_asts. `level_num` is used when level exist in ast.ImportFrom.

    When level_num = 0(e.g. from xxx import yyy), current path will be saved.
    When level_num = 1(e.g. from .xxx import yyy), current path will be saved.
    When level_num = 2(e.g. from ..xxx import yyy), the path one level above the current path will be saved.

    "Current path" is the directory containing `file_path`. ``import sys`` and
    ``sys.path.insert(0, <path>)`` are appended to the import list selected by `belonging_ast`.
    """
    file_path = os.path.dirname(os.path.realpath(file_path))
    file_path = os.path.normpath(os.path.normcase(file_path))
    # Every relative-import level beyond the first moves one directory up (no-op for 0 and 1).
    for _ in range(level_num - 1):
        file_path = os.path.dirname(file_path)
    # repr() gives a valid string literal for any path; the previous raw-string form failed to
    # parse for paths containing a quote or ending with a backslash (e.g. a Windows drive root).
    sys_path_append_ast = ast.parse(f"sys.path.insert(0, {file_path!r})").body[0]
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    import_asts.append(ast.Import([ast.alias(name='sys', asname=None)]))
    import_asts.append(sys_path_append_ast)
|
|
1577
|
+
|
|
1578
|
+
def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None):
    """
    Save imports from file.

    Registers ``sys.path`` setup for the directory of `file_path`, then appends every import found
    in the file to the import list selected by `belonging_ast`. Each import is first passed
    through ``SymbolTree._process_relative_import``; imports for which it returns a falsy value
    are dropped.

    Raises:
        RuntimeError: If `file_path` does not exist. Nothing is recorded in that case.
    """
    # Check first so that a missing file does not leave a half-registered sys.path entry behind.
    if not os.path.exists(file_path):
        raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
    self.save_file_path_to_sys(0, file_path, belonging_ast)
    with open(file_path, "r", encoding="utf-8") as f:
        source_code = f.read()
    import_nodes = AstImportFinder(ast.parse(dedent(source_code))).get_import_node()
    if not import_nodes:
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    for import_node in import_nodes:
        import_node = SymbolTree._process_relative_import(import_node, file_path)
        if import_node:
            import_asts.append(import_node)
|
|
1594
|
+
|
|
1595
|
+
def add_import(self, module: types.ModuleType, name: str, belonging_ast: ast.AST = None):
    """
    Add codes: from `module` import `name`.

    The generated statement is appended to the import list selected by `belonging_ast`:

    - ``__main__``: ``name = getattr(sys.modules['__main__'], 'name')`` is emitted instead of an
      import, to avoid executing the main script a second time;
    - ``builtins``: nothing is emitted;
    - otherwise the module path is derived from ``sys.path`` entries that prefix the module's file,
      longest first, using the first one accepted by ``SymbolTree._check_import``. If none is
      accepted, the module's directory is added to ``sys.path`` in the generated code and the
      module is imported by its file's base name.

    If `module` has no attribute `name` (e.g. it is a local variable), nothing is added.

    Args:
        module (types.ModuleType): Module that defines `name`.
        name (str): Name to import.
        belonging_ast (ast.AST): Father class or external function ast that the import belongs
            to; ``None`` for the symbol tree itself. Default: ``None``.

    Raises:
        TypeError: If `module` is not a module.
    """
    if not isinstance(module, types.ModuleType):
        raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}")
    if not hasattr(module, name):
        logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.")
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    if module.__name__ == "__main__":
        # get attr from module instead of import to avoid duplicate execution of __main__ module
        code = f"{name} = getattr(sys.modules['__main__'], '{name}')"
        code_ast = ast.parse(code).body[0]
        import_asts.append(code_ast)
    elif module.__name__ == "builtins":
        # built-in functions are not need to be imported
        pass
    else:
        # add import of obj to ast
        func_file_path = inspect.getabsfile(module)
        func_file_path = os.path.normcase(func_file_path)
        prefix_paths = []
        for path in sys.path:
            path = os.path.normcase(path)
            if func_file_path.startswith(path):
                prefix_paths.append(path)
        # Prefer the most specific (longest) sys.path entry.
        prefix_paths.sort(key=len, reverse=True)
        for path in prefix_paths:
            import_path = func_file_path[len(path):]
            import_str = import_path.replace(os.path.sep, '.')
            import_str = import_str[1:]  # remove first '.'
            mod = import_str.rsplit('.', 1)[0]  # drop the file extension
            if SymbolTree._check_import(func_file_path[:len(path)], mod):
                import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
                import_asts.append(import_node)
                break
        else:
            # No usable sys.path prefix: make the module's directory importable in generated code.
            self.save_file_path_to_sys(0, func_file_path, belonging_ast)
            mod = os.path.basename(func_file_path).rsplit('.')[0]
            import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
            import_asts.append(import_node)
|
|
1636
|
+
|
|
1637
|
+
def _get_imports_list_of_ast(self, belonging_ast: ast.AST):
|
|
1638
|
+
# get import_asts of belonging_ast
|
|
1639
|
+
import_asts = self._import_asts
|
|
1640
|
+
if belonging_ast is not None:
|
|
1641
|
+
if belonging_ast in self._father_class_ast:
|
|
1642
|
+
import_asts = self._father_class_ast.get(belonging_ast)
|
|
1643
|
+
elif belonging_ast in self._external_ast:
|
|
1644
|
+
import_asts = self._external_ast.get(belonging_ast)
|
|
1645
|
+
return import_asts
|
|
1646
|
+
|
|
1647
|
+
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
    """Resolve a node name through get_node(); any non-string argument is returned unchanged."""
    if not isinstance(node_or_name, str):
        return node_or_name
    return self.get_node(node_or_name)
|
|
1651
|
+
|
|
1652
|
+
def _handle_custom_obj_in_normalized_args(self, node: Node):
    """
    Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.

    For every normalized argument of `node` whose value type is ``ValueType.CustomObjValue``, the
    steps below are taken; other arguments are kept unchanged:

    - the object is stored on ``self._origin_network`` under a name obtained from
      ``unique_name("arg_<TypeName>")``;
    - ``self.<name> = obj.<name>`` is appended to the init function ast;
    - the argument is replaced by the naming value ``self.<name>``.

    Args:
        node (Node): A Node whose arguments and keyword arguments to be handled.

    Raises:
        TypeError: If a normalized argument is not a ScopedValue.
    """
    normalized_args: Dict[str, ScopedValue] = {}
    for key, value in node.get_normalized_args().items():
        if not isinstance(value, ScopedValue):
            raise TypeError(f"value should be ScopedValue, got: {type(value)}")
        if value.type == ValueType.CustomObjValue:
            # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
            arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
            setattr(self._origin_network, arg_name, value.value)
            # Add new code to __init__(): self.arg_name = obj.arg_name
            new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
            self._init_func_ast.body.append(new_ast)
            # Modify node's normalized_args: CustomObjValue -> self.arg_name
            normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
        else:
            normalized_args[key] = value
    node.set_normalized_args(normalized_args)
|
|
1675
|
+
|
|
1676
|
+
def _get_cls_through_file(self):
    """
    Load rewritten network class of current SymbolTree.

    1. Get source code of current SymbolTree.
    2. Save it to ``<cwd>/rewritten_network/<opt_cls_name>_<id(self)>.py``.
    3. Load that file as a module with importlib, register it in ``sys.modules`` and fetch the
       class named ``self._opt_cls_name``.

    Returns:
        A class handle.

    Raises:
        ImportError: If the generated file cannot be loaded as a module.
        RuntimeError: If the loaded module does not define the expected class.
    """
    file_path = os.path.join(os.getcwd(), "rewritten_network")
    if not os.path.exists(file_path):
        try:
            os.mkdir(file_path, mode=0o700)
        except FileExistsError:
            # The directory may have been created concurrently.
            pass
    file_name = f"{self._opt_cls_name}_{id(self)}.py"
    network_file = os.path.join(file_path, file_name)
    # Generate the code before touching the file so a failure does not leave an empty file.
    source = self.get_code()
    # O_TRUNC: id(self) can be reused by a later SymbolTree, and an existing longer file must not
    # keep stale trailing bytes.
    with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRWXU), 'wb') as f:
        f.write(source.encode('utf-8'))
        f.flush()
        os.fsync(f)
    tmp_module_path, tmp_module_file = os.path.split(network_file)
    tmp_module_name = tmp_module_file[:-3]
    if tmp_module_path not in sys.path:
        sys.path.append(tmp_module_path)
    tmp_module = None

    retries = 0
    while not tmp_module:
        spec = importlib.util.spec_from_file_location(tmp_module_name, network_file)
        if spec:
            tmp_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(tmp_module)
        else:
            logger.warning(f"load module {tmp_module_name} failed, retrying.")
            if retries > 10:
                break
            time.sleep(0.5)
            retries += 1
    if not tmp_module:
        raise ImportError(f"load module {tmp_module_name} failed.")
    # Save new module to sys.modules to support inspect.getsource().
    sys.modules[tmp_module_name] = tmp_module
    network_cls = getattr(tmp_module, self._opt_cls_name, None)
    if network_cls is None:
        raise RuntimeError(f"Can not find network class: {self._opt_cls_name}")
    return network_cls
|
|
1725
|
+
|
|
1726
|
+
def _on_change(self, event: Event):
    """Mark this symbol tree as modified, then forward `event` to changed()."""
    self._modified = True
    self.changed(event)
|
|
1729
|
+
|
|
1730
|
+
def _cal_difference_set(self, input, other):
|
|
1731
|
+
"""Calculate different set of two sets."""
|
|
1732
|
+
set1 = set(input)
|
|
1733
|
+
set2 = set(other)
|
|
1734
|
+
return set1 - set2
|
|
1735
|
+
|
|
1736
|
+
def _merge_origin_property(self, new_net):
|
|
1737
|
+
"""Merge property of two network."""
|
|
1738
|
+
tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net))
|
|
1739
|
+
new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys())
|
|
1740
|
+
for name in new_attr_names:
|
|
1741
|
+
setattr(new_net, name, getattr(self._origin_network, name))
|
|
1742
|
+
# merger cells
|
|
1743
|
+
cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys())
|
|
1744
|
+
cells = self._cal_difference_set(cells, self._deleted_node)
|
|
1745
|
+
for c in cells:
|
|
1746
|
+
new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
|
|
1747
|
+
# merge primitives
|
|
1748
|
+
# pylint: disable=protected-access
|
|
1749
|
+
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1750
|
+
for p in primitives:
|
|
1751
|
+
new_net._primitives[p] = self._origin_network._primitives[p] # pylint: disable=protected-access
|
|
1752
|
+
|
|
1753
|
+
def _process_duplicate_name_modules(self, module_ast: ast.Module):
|
|
1754
|
+
"""Adjust names of imported modules with same name and different import path."""
|
|
1755
|
+
# {name1: [path1, path2, ...], ...}
|
|
1756
|
+
name_path_dict: Dict[str, List[str]] = {}
|
|
1757
|
+
# names of modules need to be suffixed: {name1: suffixed_name1, ...}
|
|
1758
|
+
name_need_suffix: Dict[str, str] = {}
|
|
1759
|
+
# used to record replace actions in ast.ImportFrom
|
|
1760
|
+
import_replacer = AstReplacer(None)
|
|
1761
|
+
self._tmp_replacers.append(import_replacer)
|
|
1762
|
+
|
|
1763
|
+
def suffix_alias(alias: ast.alias, suffix: int):
|
|
1764
|
+
"""suffix the name of alias in ast.ImportFrom"""
|
|
1765
|
+
new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}"
|
|
1766
|
+
import_replacer._trace.append((alias, 'asname', alias.asname, new_name)) # pylint: disable=protected-access
|
|
1767
|
+
alias.asname = new_name
|
|
1768
|
+
return new_name
|
|
1769
|
+
|
|
1770
|
+
def is_divider(ast_node):
|
|
1771
|
+
"""judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
|
|
1772
|
+
return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'
|
|
1773
|
+
|
|
1774
|
+
def record_imports(ast_node: ast.ImportFrom):
|
|
1775
|
+
"""record name and path of imported modules to find the duplicate name modules."""
|
|
1776
|
+
for alias in ast_node.names[:]:
|
|
1777
|
+
name = alias.asname if alias.asname else alias.name
|
|
1778
|
+
if name == '*':
|
|
1779
|
+
continue
|
|
1780
|
+
# current name is firstly imported, just record it
|
|
1781
|
+
if name not in name_path_dict:
|
|
1782
|
+
name_path_dict[name] = [ast_node.module]
|
|
1783
|
+
continue
|
|
1784
|
+
# current name is imported before, check whether it is a duplicated name
|
|
1785
|
+
for idx, path in enumerate(name_path_dict[name]):
|
|
1786
|
+
if path.startswith(ast_node.module):
|
|
1787
|
+
# e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A'
|
|
1788
|
+
# then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx}
|
|
1789
|
+
name_path_dict[name][idx] = ast_node.module
|
|
1790
|
+
if idx > 0:
|
|
1791
|
+
name_need_suffix[name] = suffix_alias(alias, idx)
|
|
1792
|
+
break
|
|
1793
|
+
elif ast_node.module.startswith(path):
|
|
1794
|
+
# e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A'
|
|
1795
|
+
# then we just need to update name to A_{idx}
|
|
1796
|
+
if idx > 0:
|
|
1797
|
+
name_need_suffix[name] = suffix_alias(alias, idx)
|
|
1798
|
+
break
|
|
1799
|
+
else:
|
|
1800
|
+
# current name is imported from a new path, save the path and update the name
|
|
1801
|
+
name_path_dict[name].append(ast_node.module)
|
|
1802
|
+
name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1)
|
|
1803
|
+
|
|
1804
|
+
def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]):
|
|
1805
|
+
"""suffix names in ast.ClassDef or ast.FunctionDef"""
|
|
1806
|
+
if not name_need_suffix:
|
|
1807
|
+
return
|
|
1808
|
+
name_replacer = AstReplacer(ast_node)
|
|
1809
|
+
self._tmp_replacers.append(name_replacer)
|
|
1810
|
+
for name, new_name in name_need_suffix.items():
|
|
1811
|
+
name_replacer.replace_all(name, new_name)
|
|
1812
|
+
|
|
1813
|
+
for ast_node in module_ast.body:
|
|
1814
|
+
if isinstance(ast_node, ast.ImportFrom):
|
|
1815
|
+
record_imports(ast_node)
|
|
1816
|
+
if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
|
|
1817
|
+
suffix_names_in_ast(ast_node)
|
|
1818
|
+
if is_divider(ast_node):
|
|
1819
|
+
name_need_suffix.clear()
|