mindspore-2.4.0-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
- mindspore/_c_expression.cpython-310-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1428 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Node class define of Rewrite. See detail in Node class docstring."""
|
|
16
|
+
from typing import Optional, Union, List, Dict
|
|
17
|
+
import ast
|
|
18
|
+
import inspect
|
|
19
|
+
from types import FunctionType
|
|
20
|
+
import sys
|
|
21
|
+
|
|
22
|
+
from mindspore.nn import Cell
|
|
23
|
+
from mindspore.ops import Primitive
|
|
24
|
+
from mindspore import log as logger
|
|
25
|
+
from ..api.scoped_value import ScopedValue, ValueType
|
|
26
|
+
from ..api.node_type import NodeType
|
|
27
|
+
from ..common.namespace import is_subtree
|
|
28
|
+
from ..common.error_log import error_str
|
|
29
|
+
from ..ast_helpers import AstModifier, AstReplacer, AstConverter
|
|
30
|
+
from ... import _checkparam as Validator
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if sys.version_info >= (3, 9):
|
|
34
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
35
|
+
else:
|
|
36
|
+
import astunparse
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LocalPrim(Primitive):
    """
    Marker primitive used to indicate a local primitive instance.

    The primitive is always registered under the fixed name "rewrite_local_prim"; the object it stands for is kept
    in the `prim_obj` attribute.
    """

    def __init__(self, prim_obj: type):
        """
        Initialize the base Primitive with the fixed name and remember the wrapped object.

        Args:
            prim_obj (type): The object this local primitive stands for. Stored unchanged in `self.prim_obj`.
        """
        super(LocalPrim, self).__init__("rewrite_local_prim")
        self.prim_obj = prim_obj
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Node:
|
|
47
|
+
"""
|
|
48
|
+
Node is a data structure represents a source code line in network. For the most part, Node represents an operator
|
|
49
|
+
invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
|
|
50
|
+
Node has different meaning in different type of node:
|
|
51
|
+
|
|
52
|
+
- CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
|
|
53
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
|
|
54
|
+
`kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
|
|
55
|
+
method. `func` is corresponding to func of call expression which means symbol of the cell-op.
|
|
56
|
+
- CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
|
|
57
|
+
`targets`, `args`, `kwargs` and `func_name` are as previous.
|
|
58
|
+
- CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
|
|
59
|
+
`targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
|
|
60
|
+
represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
|
|
61
|
+
method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
|
|
62
|
+
mapped to CallMethod node whose `func_name` is "PassThrough".
|
|
63
|
+
- Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
|
|
64
|
+
supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
|
|
65
|
+
- Input: an input node represents an input of current network which also a parameter of forward method of Cell.
|
|
66
|
+
`targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
|
|
67
|
+
of forward function. `kwargs` and `func_name` are don't-care.
|
|
68
|
+
- Output: an output node represents the output of current network which is corresponding to return statement of
|
|
69
|
+
forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
|
|
70
|
+
`kwargs` are don't-care.
|
|
71
|
+
- Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
|
|
72
|
+
`targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
|
|
73
|
+
SymbolTree instance.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
             func_name: Optional[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue], name: str,
             instance):
    """
    Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such
    as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
    `create_output_node`, etc. rather than invoking constructor of Node directly.

    Args:
        node_type (NodeType): A NodeType as type of Node.
        ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
            not be None except when node type is Unknown.
        targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
            ``None`` is stored as an empty list.
        func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
        args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
            ``None`` is counted as zero positional arguments.
        kwargs (Dict[str, ScopedValue]): A dict mapping keyword names to instances of ScopedValue. See detail in
            docstring of Node class. ``None`` is counted as zero keyword arguments.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.
        instance: Object in network corresponding to this node.
    """
    self._node_type: NodeType = node_type
    self._ast_node: Optional[ast.AST] = ast_node
    # Attributes are only collected from `instance` for module/cell/primitive call nodes; other node types keep
    # an empty dict.
    self._attribute: Dict[str, object] = {}
    if node_type in (NodeType.CallModule, NodeType.CallCell, NodeType.CallPrimitive):
        self._attribute = Node._get_cell_or_prim_op_attribute(instance)
    self._instance = instance
    self._name = name
    self._func_name: Optional[ScopedValue] = func_name
    self._targets: [ScopedValue] = targets if targets is not None else []
    # Number of positional / keyword arguments as originally passed in (0 when the container is None).
    self._args_num = len(args) if args is not None else 0
    self._kwargs_num = len(kwargs) if kwargs is not None else 0
    # NOTE(review): initialised before `_get_normalized_args` is called; presumably that helper fills this list,
    # so keep this ordering - confirm against the helper.
    self._normalized_args_keys = []  # for saving args' order
    self._normalized_args = self._get_normalized_args(args, kwargs)
    # position in graph nodes list
    # it will affect code-order of python code
    self._prev: Optional[Node] = None
    self._next: Optional[Node] = None
    # A handler of SymbolTree current node belonging to
    self._belong_tree = None
    # A handler of NodeManager current node belonging to
    self._node_manager = None
    # A dict that records which target of which Node current Node's argument come from
    self._arg_providers: {int: (Node, int)} = {}
    # A dict that records which argument of which Node uses current Node's target
    self._target_users: {int: [(Node, int)]} = {}
    # Indicate this node represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
    self._type_cls = None
    # Indicate this node represent the initialize of a class type, e.g. abs_inst = P.Abs()
    self._init_cls = None
|
|
125
|
+
|
|
126
|
+
@classmethod
|
|
127
|
+
def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
                       func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
                       kwargs: {str: ScopedValue}=None, name: str = ""):
    """
    Class method of Node. Instantiate an instance of node whose type is CallMethod. A CallMethod node represents
    an assign statement whose value is a call to a python method.

    Args:
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
            must not be None currently.
        targets (list[Union[ScopedValue, str]]): Targets of the assign statement. They are passed through
            `Node._handle_targets` before being stored.
        func_name (Union[ScopedValue, str]): Name of the called method. A `str` is converted to a naming
            `ScopedValue`.
        args (list[ScopedValue]): Positional arguments of the call. Default: ``None`` , treated as an empty list.
        kwargs (dict{str: ScopedValue}): Keyword arguments of the call. Default: ``None`` , treated as an empty
            dict.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.

    Raises:
        RuntimeError: If `ast_node` is None.

    Returns:
        An instance of `cls` whose node type is `NodeType.CallMethod`.
    """
    call_args = [] if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    scoped_func = ScopedValue.create_naming_value(func_name) if isinstance(func_name, str) else func_name
    scoped_targets = Node._handle_targets(targets)
    if ast_node is None:
        raise RuntimeError("Input ast_node is None")
    return cls(NodeType.CallMethod, ast_node, scoped_targets, scoped_func, call_args, call_kwargs, name, None)
|
|
154
|
+
|
|
155
|
+
@classmethod
|
|
156
|
+
def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None):
    """
    Class method of Node. Instantiate an instance of node whose type is Python. A Python node holds a python
    statement that is not supported by Rewrite or is ignored by Rewrite.

    Args:
        ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.
        instance: An object corresponding to this node in network.

    Returns:
        An instance of `cls` whose node type is `NodeType.Python`, created without targets, func_name, args or
        kwargs.
    """
    # Positional order: node_type, ast_node, targets, func_name, args, kwargs, name, instance.
    ctor_args = (NodeType.Python, ast_node, None, None, [], {}, name, instance)
    return cls(*ctor_args)
|
|
168
|
+
|
|
169
|
+
@classmethod
|
|
170
|
+
def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
                      name: str = ""):
    """
    Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
    SymbolTree which is corresponding to parameters of forward function.

    Args:
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. When it is
            None, an `ast.arg` named `arg_name` without annotation is created and used instead.
        arg_name (str): A string represents name of parameter. It becomes the single target of the node.
        default ([ScopedValue, optional]): An instance of ScopedValue represents default value of parameter.
            When given it becomes the only element of the node's args. Default: ``None`` , no args.
        name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
            Name of node also used as field name in network class.

    Returns:
        An instance of `cls` whose node type is `NodeType.Input`.
    """
    target = ScopedValue.create_naming_value(arg_name)
    args = [] if default is None else [default]
    if ast_node is None:
        # `annotation` is an optional expression node: "no annotation" is spelled None. The previous empty
        # string only unparsed correctly because it is falsy and is not a valid AST field value.
        ast_node = ast.arg(arg_name, annotation=None)
    return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
|
|
191
|
+
|
|
192
|
+
@classmethod
|
|
193
|
+
def create_output_node(cls, ast_node: ast.AST, return_value: [ScopedValue], name: str = "return"):
    """
    Class method of Node. Instantiate an instance of node whose type is Output. An Output node represents output of
    SymbolTree which is corresponding to return statement of forward function.

    Args:
        ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
        return_value (list[ScopedValue]): A list of instance of ScopedValue represents the returned values. They
            are stored as args of the node.
        name (str): A string represents name of node. Default: ``"return"`` .

    Returns:
        An instance of `cls` whose node type is `NodeType.Output` and whose func_name is the naming value
        "return".
    """
    return_func = ScopedValue.create_naming_value("return")
    output_node = cls(NodeType.Output, ast_node, None, return_func, return_value, {}, name, None)
    return output_node
|
|
205
|
+
|
|
206
|
+
@classmethod
|
|
207
|
+
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
                        op_type: ScopedValue, args: [ScopedValue], name: str = ""):
    """
    Class method of Node. Instantiate an instance of node whose type is `MathOps` .
    A mathops node is used to represent a node with mathematical operations, such as
    `y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.

    Args:
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
            node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
            ast.Compare.
        targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
            See detail in docstring of Node class.
        op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
            It is stored as func_name of the node.
        args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
            sequentially in the list.
        name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
            Name of node also used as field name in network class. The format of mathops node name
            is 'AstNodeName_AstOpName_n'.

    Returns:
        An instance of `cls` whose node type is `NodeType.MathOps`; kwargs and instance are passed as None.
    """
    # Positional order: node_type, ast_node, targets, func_name, args, kwargs, name, instance.
    ctor_args = (NodeType.MathOps, ast_node, targets, op_type, args, None, name, None)
    return cls(*ctor_args)
|
|
228
|
+
|
|
229
|
+
@staticmethod
|
|
230
|
+
def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
                          kwargs: {str: ScopedValue}=None):
    """
    Build a node for a call to `function`.

    The node is created through `Node.inner_create_call_function` without an ast node; both the node name and
    the function's scoped name are taken from `function.__name__`.

    Args:
        function (FunctionType): The function to be called.
        targets (list[Union[ScopedValue, str]]): Indicates output names. Used as targets of an assign statement in
            source code. They are passed through `Node._handle_targets` first.
        args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
            source code. Default: ``None`` , which indicates the `function` has no args inputs.
        kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
            Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
            code. Default: ``None`` , which indicates the `function` has no kwargs inputs.

    Returns:
        An instance of `Node`.
    """
    call_args = [] if args is None else args
    call_kwargs = {} if kwargs is None else kwargs
    scoped_targets = Node._handle_targets(targets)
    plain_name = function.__name__
    scoped_name = ScopedValue.create_naming_value(plain_name)
    return Node.inner_create_call_function(plain_name, None, scoped_name, function, scoped_targets, call_args,
                                           call_kwargs)
|
|
256
|
+
|
|
257
|
+
@classmethod
|
|
258
|
+
def inner_create_call_function(cls, node_name: str, ast_node: ast.Assign, func_name: ScopedValue, func_obj: object,
                               targets: List[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue]):
    """
    Instantiate an instance of node whose type is `CallFunction`.

    Args:
        node_name (str): Name of node.
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
        func_name (ScopedValue): Name of function.
        func_obj (Object): An object of function. See detail in docstring of Node class.
        targets (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
        args (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
        kwargs (Dict[str, ScopedValue]): A dict of instance of `ScopedValue`. See detail in docstring of `Node`
            class.

    Returns:
        The newly created `CallFunction` instance.
    """
    # Imported here rather than at module level; presumably to avoid a circular import with the package
    # that defines CallFunction - TODO confirm.
    from . import CallFunction
    call_function_node = CallFunction(targets, func_name, args, kwargs, node_name, ast_node, None, None, func_obj,
                                      False)
    return call_function_node
|
|
276
|
+
|
|
277
|
+
@staticmethod
|
|
278
|
+
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
                   args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
                   is_sub_net: bool = False):
    """
    Static method of Node. Instantiate a node for an invoking to `op`.

    When `is_sub_net` is true and `is_subtree(op)` holds, a SymbolTree is built from `op` and a TreeNode is
    returned. Otherwise the node is created by `Node.create_call_buildin_op`, which yields a `CallCell` node for a
    `Cell` and a `CallPrimitive` node for a `Primitive`.

    Args:
        op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
        targets (list[Union[ScopedValue, str]]): Targets of the node. See detail in docstring of Node class.
        args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
            Default: ``None`` , treated as an empty list.
        kwargs (dict{str: ScopedValue}): A dict of instance of `ScopedValue`. See detail in docstring of `Node`
            class. Default: ``None`` , treated as an empty dict.
        node_name (str): A string represents name of node. Name of node will be unique when inserted into
            `SymbolTree`. Name of node also used as field name in network class.
        is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the
            `cell` to a TreeNode, else a CallCell Node. Default is a False.

    Returns:
        The created node.
    """
    owner = "Node"
    Validator.check_value_type("op", op, [Cell, Primitive], owner)
    if ast_node is not None:
        Validator.check_value_type("ast_node", ast_node, [ast.AST], owner)
    Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], owner)
    # Containers supplied by the caller are validated; omitted ones fall back to empty containers.
    if args is None:
        args = []
    else:
        Validator.check_element_type_of_iterable("args", args, [ScopedValue], owner)
    if kwargs is None:
        kwargs = {}
    else:
        Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], owner)
    Validator.check_value_type("node_name", node_name, [str], owner)
    scoped_targets = Node._handle_targets(targets)
    func_name = ScopedValue.create_naming_value(node_name) if isinstance(node_name, str) else node_name

    if not (is_sub_net and is_subtree(op)):
        return Node.create_call_buildin_op(op, ast_node, scoped_targets, func_name, args, kwargs, node_name)

    # Sub-network: build its SymbolTree and rename the original class name to the optimized one in the class ast.
    from ..symbol_tree import SymbolTreeBuilder
    stree = SymbolTreeBuilder(op).build()
    AstReplacer(stree.get_class_ast()).replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
    return TreeNode.create_tree_node(stree, ast_node, scoped_targets, func_name, args, kwargs, node_name, op)
|
|
326
|
+
|
|
327
|
+
@classmethod
|
|
328
|
+
def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
                           func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
                           node_name: str = ""):
    """
    Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
    A `CallCell` node represents an invoking to cell-op.
    A `CallPrimitive` node represents an invoking to primitive-op.

    Args:
        op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
        ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
        targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
        func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
        args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
        kwargs (dict{str: ScopedValue}): A dict of instance of `ScopedValue`. See detail in docstring of `Node`
            class.
        node_name (str): A string represents name of node. Name of node will be unique when inserted into
            `SymbolTree`. Name of node also used as field name in network class.

    Raises:
        ValueError: If `op` is neither a `Cell` nor a `Primitive`.

    Returns:
        An instance of `cls` whose node type is `NodeType.CallCell` when `op` is a `Cell`, otherwise
        `NodeType.CallPrimitive`.
    """
    if not isinstance(op, (Cell, Primitive)):
        # Build one message string; passing the type as a second exception argument made the error render as
        # a tuple repr.
        raise ValueError("Input op is not a buildin op(Cell or Primitive): {}".format(type(op)))
    node_type = NodeType.CallCell if isinstance(op, Cell) else NodeType.CallPrimitive
    return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
|
|
355
|
+
|
|
356
|
+
@staticmethod
|
|
357
|
+
def _get_construct_arg_names(parameters):
|
|
358
|
+
"""
|
|
359
|
+
Static method of `Node`. Get parameters' names of the construct function.
|
|
360
|
+
|
|
361
|
+
Args:
|
|
362
|
+
parameters (MappingProxyType): An ordered mapping of parameters' names to the corresponding Parameter
|
|
363
|
+
objects.
|
|
364
|
+
|
|
365
|
+
Raises:
|
|
366
|
+
RuntimeError: Invalid parameter kind.
|
|
367
|
+
|
|
368
|
+
Returns:
|
|
369
|
+
- arg_names, Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD].
|
|
370
|
+
- var_positional_name, Name of VAR_POSITIONAL parameters.
|
|
371
|
+
- var_keyword_name, Name of VAR_KEYWORD parameters.
|
|
372
|
+
"""
|
|
373
|
+
position_only_names: [str] = []
|
|
374
|
+
positional_or_keyword_names: [str] = []
|
|
375
|
+
var_positional_name = None
|
|
376
|
+
keyword_only_names: [str] = []
|
|
377
|
+
var_keyword_name = None
|
|
378
|
+
for name, para in parameters.items():
|
|
379
|
+
if para.kind == inspect.Parameter.POSITIONAL_ONLY: # parameters which appear before a '/'
|
|
380
|
+
position_only_names.append(name)
|
|
381
|
+
elif para.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: # parameters which appear before '*' or '*args'
|
|
382
|
+
positional_or_keyword_names.append(name)
|
|
383
|
+
elif para.kind == inspect.Parameter.VAR_POSITIONAL: # corresponds to a '*args'
|
|
384
|
+
var_positional_name = name
|
|
385
|
+
elif para.kind == inspect.Parameter.KEYWORD_ONLY: # parameters which appear after '*' and before '**'
|
|
386
|
+
keyword_only_names.append(name)
|
|
387
|
+
elif para.kind == inspect.Parameter.VAR_KEYWORD: # corresponds to a '**kwargs'
|
|
388
|
+
var_keyword_name = name
|
|
389
|
+
else:
|
|
390
|
+
raise RuntimeError("invalid parameter kind:", para.kind)
|
|
391
|
+
if "self" in position_only_names:
|
|
392
|
+
position_only_names.remove("self")
|
|
393
|
+
if "self" in positional_or_keyword_names:
|
|
394
|
+
positional_or_keyword_names.remove("self")
|
|
395
|
+
names = (position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names,
|
|
396
|
+
var_keyword_name)
|
|
397
|
+
return names
|
|
398
|
+
|
|
399
|
+
@staticmethod
|
|
400
|
+
def _map_args_names(names: tuple, args: [ScopedValue], kwargs: {str: ScopedValue},
                    normalized_args_keys: [str], normalized_args: {str: ScopedValue}):
    """
    Fill in normalized_args according to the order of parameters of construct func.

    Positional arguments are keyed by the positional-only names first, then by the positional-or-keyword
    names; surplus positional arguments are keyed "<var_positional_name>_<index>". Keyword arguments are
    appended afterwards, ordered by the declaration order of the named parameters. Keys that match no
    named parameter are accepted only when a VAR_KEYWORD parameter exists; they are appended last, in the
    order they were passed.

    Args:
        names (tuple): A 5-tuple of (position_only_names, positional_or_keyword_names, var_positional_name,
            keyword_only_names, var_keyword_name) got from construct func.
        args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
        kwargs (dict{str: ScopedValue}): A dict of str to instance of ScopedValue. See detail in docstring
            of Node class.
        normalized_args_keys (list[str]): Output list; receives the keys in normalized order.
        normalized_args (dict{str: ScopedValue}): The normalized args to be filled.

    Raises:
        RuntimeError: Input args are invalid (more positional args than parameters and no VAR_POSITIONAL).
        RuntimeError: Arg name already exist in kwargs.
        RuntimeError: Input kwargs invalid (unknown key and no VAR_KEYWORD parameter).
    """
    position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, var_keyword_name = \
        names
    positional_names = list(position_only_names) + list(positional_or_keyword_names)

    for arg_index, arg in enumerate(args):
        if arg_index < len(positional_names):
            arg_key = positional_names[arg_index]
        elif var_positional_name:
            arg_key = "{}_{}".format(var_positional_name, arg_index)
        else:
            raise RuntimeError("Input args are invalid.")

        if arg_key in kwargs:
            raise RuntimeError("Arg name already exist in kwargs.")
        normalized_args[arg_key] = arg
        normalized_args_keys.append(arg_key)

    # Rank only the parameters that can be addressed by name. The names of the *args / **kwargs
    # parameters themselves are deliberately excluded: a keyword that happens to be spelled like them is
    # just another extra keyword and belongs to the var-keyword tail.
    parameter_rank = {}
    for name in (*positional_names, *keyword_only_names):
        parameter_rank.setdefault(name, len(parameter_rank))

    next_extra_rank = len(parameter_rank)
    ranked_kwargs = []
    for arg_key, value in kwargs.items():
        if arg_key in parameter_rank:
            ranked_kwargs.append((parameter_rank[arg_key], arg_key, value))
            continue
        if not var_keyword_name:
            raise RuntimeError("Input kwargs invalid.")
        # Extra keywords keep their call order after all named parameters.
        ranked_kwargs.append((next_extra_rank, arg_key, value))
        next_extra_rank += 1

    ranked_kwargs.sort(key=lambda item: item[0])
    for _, arg_key, value in ranked_kwargs:
        normalized_args[arg_key] = value
        normalized_args_keys.append(arg_key)
|
|
456
|
+
|
|
457
|
+
@staticmethod
|
|
458
|
+
def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]:
    """
    Replace every CustomObjValue argument by a NamingValue placeholder.

    Args:
        args (list[ScopedValue]): Arguments to scan. Each element must be a ScopedValue.

    Returns:
        A new list of the same length. Elements whose type is ValueType.CustomObjValue are replaced by
        ScopedValue.create_naming_value("custom-object", "self"); all other elements are kept as they are.

    Raises:
        TypeError: If an element is not an instance of ScopedValue.
    """
    converted = []
    for item in args:
        if not isinstance(item, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(item))
        if item.type != ValueType.CustomObjValue:
            converted.append(item)
            continue
        logger.info("custom-object exist in args, should be replace before compile")
        converted.append(ScopedValue.create_naming_value("custom-object", "self"))
    return converted
|
|
478
|
+
|
|
479
|
+
@staticmethod
|
|
480
|
+
def _handle_custom_obj_in_kwargs(kwargs: {str: ScopedValue}) -> {str: ScopedValue}:
    """
    Replace every CustomObjValue keyword argument by a NamingValue placeholder.

    Args:
        kwargs (dict{str: ScopedValue}): Keyword arguments to scan. Each value must be a ScopedValue.

    Returns:
        A new dict with the same keys. Values whose type is ValueType.CustomObjValue are replaced by
        ScopedValue.create_naming_value("custom-object", "self"); all other values are kept as they are.

    Raises:
        TypeError: If a value is not an instance of ScopedValue.
    """
    converted = {}
    for key, scoped_value in kwargs.items():
        if not isinstance(scoped_value, ScopedValue):
            raise TypeError("value should be ScopedValue, got: ", type(scoped_value))
        is_custom_obj = scoped_value.type == ValueType.CustomObjValue
        converted[key] = ScopedValue.create_naming_value("custom-object", "self") if is_custom_obj \
            else scoped_value
    return converted
|
|
499
|
+
|
|
500
|
+
@staticmethod
|
|
501
|
+
def _handle_targets(targets: [Union[ScopedValue, str]]) -> [ScopedValue]:
    """
    Normalize targets into a list of ScopedValue.

    A str target is turned into a NamingValue. Text before the last '.' becomes the scope and the rest
    becomes the name; a str without '.' gets an empty scope.

    Args:
        targets (list[Union[ScopedValue, str]]): Targets to be normalized.

    Returns:
        A list of instance of ScopedValue, in the same order as `targets`.

    Raises:
        TypeError: If `targets` is not a list.
        RuntimeError: If an element is neither a str nor a ScopedValue.
    """
    if not isinstance(targets, list):
        raise TypeError("targets should be list, got: ", type(targets))

    def _normalize(target):
        if isinstance(target, str):
            # rpartition yields an empty scope when there is no '.' in the target.
            scope, _, name = target.rpartition('.')
            return ScopedValue.create_naming_value(name, scope)
        if isinstance(target, ScopedValue):
            return target
        raise RuntimeError("Invalid symbol type: ", target)

    return [_normalize(target) for target in targets]
|
|
527
|
+
|
|
528
|
+
@staticmethod
|
|
529
|
+
def _get_cell_or_prim_op_attribute(obj) -> dict:
    """
    Collect the public instance attributes of a cell-op or primitive-op.

    Args:
        obj: A cell-op or a primitive-op, or None.

    Returns:
        An empty dict when `obj` is None. Otherwise a dict holding every entry of `obj.__dict__` whose
        key does not start with "_", plus the key "cls" mapped to the class of `obj`.
    """
    if obj is None:
        return {}
    attributes = {key: value for key, value in obj.__dict__.items() if not key.startswith("_")}
    # "cls" is written last so it wins over an instance attribute of the same name.
    attributes["cls"] = obj.__class__
    return attributes
|
|
548
|
+
|
|
549
|
+
def get_type_cls(self) -> object:
    """Return the class type object this node represents, e.g. abs_ops = _get_cache_prim(P.Abs)."""
    type_cls = self._type_cls
    return type_cls
|
|
552
|
+
|
|
553
|
+
def set_type_cls(self, x):
    """Store `x` as the class type object this node represents, e.g. abs_ops = _get_cache_prim(P.Abs)."""
    self._type_cls = x
|
|
556
|
+
|
|
557
|
+
def get_init_cls(self) -> object:
    """Return the class type object initialized by this node, e.g. abs_inst = P.Abs()."""
    init_cls = self._init_cls
    return init_cls
|
|
560
|
+
|
|
561
|
+
def set_init_cls(self, x):
    """Store `x` as the class type object initialized by this node."""
    self._init_cls = x
|
|
564
|
+
|
|
565
|
+
def get_prev(self) -> 'Node':
    """
    Get the node that precedes this one in source code order.

    Returns:
        The stored previous node (`_prev`).
    """
    previous_node = self._prev
    return previous_node
|
|
573
|
+
|
|
574
|
+
def get_next(self) -> 'Node':
    """
    Get the node that follows this one in source code order.

    Returns:
        The stored next node (`_next`).
    """
    next_node = self._next
    return next_node
|
|
582
|
+
|
|
583
|
+
def set_prev(self, node: 'Node'):
    """
    Record `node` as the node preceding this one.

    Only this node's own link is written; the other node is not modified.

    Args:
        node (Node): Node to be set as previous node of current node.
    """
    self._prev = node
|
|
591
|
+
|
|
592
|
+
def set_next(self, node: 'Node'):
    """
    Record `node` as the node following this one.

    Only this node's own link is written; the other node is not modified.

    Args:
        node (Node): Node to be set as next node of current node.
    """
    self._next = node
|
|
600
|
+
|
|
601
|
+
def get_ast(self) -> Optional[ast.AST]:
    """
    Getter of _ast_node.

    Returns:
        The stored ast node, which is None when no ast node is stored.
    """
    ast_node = self._ast_node
    return ast_node
|
|
609
|
+
|
|
610
|
+
def set_ast(self, ast_node: ast.AST):
    """
    Setter of _ast_node.

    Args:
        ast_node (ast.AST): An instance of ast.AST as new value for _ast_node.

    Raises:
        TypeError: If `ast_node` is not an instance of ast.AST; `_ast_node` is left unchanged.
    """
    if isinstance(ast_node, ast.AST):
        self._ast_node = ast_node
        return
    raise TypeError("ast_node should be ast.AST, got: ", type(ast_node))
|
|
620
|
+
|
|
621
|
+
def get_belong_symbol_tree(self):
    """Return the symbol tree this node belongs to."""
    belong_tree = self._belong_tree
    return belong_tree
|
|
624
|
+
|
|
625
|
+
def set_belong_symbol_tree(self, symbol_tree):
    """Record `symbol_tree` as the symbol tree this node belongs to."""
    self._belong_tree = symbol_tree
|
|
628
|
+
|
|
629
|
+
def get_node_manager(self):
    """Return the NodeManager this node belongs to."""
    node_manager = self._node_manager
    return node_manager
|
|
632
|
+
|
|
633
|
+
def set_node_manager(self, node_manager):
    """Record `node_manager` as the NodeManager this node belongs to."""
    self._node_manager = node_manager
|
|
636
|
+
|
|
637
|
+
def isolate(self):
    """
    Detach this node from the source code order list.

    The former previous and next nodes (when present) are linked to each other, then both links of this
    node are reset to None.
    """
    before, after = self.get_prev(), self.get_next()
    if before is not None:
        before.set_next(after)
    if after is not None:
        after.set_prev(before)
    # The node no longer has neighbours once it is unlinked.
    self.set_prev(None)
    self.set_next(None)
|
|
647
|
+
|
|
648
|
+
def insert_before(self, node: 'Node'):
    """
    Insert a node before current node in source code list. Note that topological order is not determined here.

    `node` is first isolated from wherever it is currently linked, then spliced between the current
    node and its former previous node.

    Args:
        node (Node): An instance of node to be inserted in.
    """
    node.isolate()
    former_prev = self.get_prev()
    if former_prev is not None:
        former_prev.set_next(node)
    node.set_prev(former_prev)
    node.set_next(self)
    self.set_prev(node)
|
|
662
|
+
|
|
663
|
+
def insert_after(self, node: 'Node'):
    """
    Insert a node after current node in source code list. Note that topological order is not determined here.

    `node` is first isolated from wherever it is currently linked, then spliced between the current
    node and its former next node.

    Args:
        node (Node): An instance of node to be inserted in.
    """
    node.isolate()
    former_next = self.get_next()
    self.set_next(node)
    node.set_prev(self)
    node.set_next(former_next)
    if former_next is None:
        return
    former_next.set_prev(node)
|
|
677
|
+
|
|
678
|
+
def get_inputs(self) -> ['Node']:
    """
    Get input nodes of current node in topological order.

    The first element of every non-empty provider returned by `get_arg_providers()` is taken as an
    input node; empty providers are skipped.

    Returns:
        A list of instances of Node as input nodes.
    """
    return [provider[0] for provider in self.get_arg_providers().values() if provider]
|
|
691
|
+
|
|
692
|
+
def get_users(self) -> ['Node']:
    """
    Get user nodes of current node in topological order.

    Walks the users of every target returned by `get_target_users()` and keeps each user node once,
    in order of first appearance.

    Returns:
        A list of instances of Node as user nodes.
    """
    unique_users = []
    for users_of_target in self.get_target_users().values():
        # An empty (or missing) user list contributes nothing.
        for user, _ in users_of_target or ():
            if user not in unique_users:
                unique_users.append(user)
    return unique_users
|
|
707
|
+
|
|
708
|
+
def get_targets(self) -> [ScopedValue]:
    """
    Getter of _targets.

    - For CallCell, CallPrimitive, CallMethod or Tree nodes, `targets` name the results of invoking the
      cell-op, primitive-op or function-call, i.e. they correspond to the targets of ast.Assign.
    - For Input nodes, `targets` should hold exactly one element naming the function parameter.
    - For Python or Output nodes, `targets` are don't-care.

    Returns:
        A list of instances of ScopedValue as targets of node.
    """
    targets = self._targets
    return targets
|
|
723
|
+
|
|
724
|
+
def set_targets(self, targets: [ScopedValue]):
    """
    Setter of _targets.

    Note:
        This interface can only be called before node been inserted into symbol-tree because target will be
        unique while insert into symbol-tree, in other word, set_targets is not a user-interface.

        For node types backed by an assign statement (CallCell, CallMethod, CallPrimitive, Tree,
        CallFunction, CellContainer, MathOps) the new targets are also synchronized to the ast node.
        For other node types only `_targets` is updated.

        For Input nodes, `targets` should hold exactly one element naming the function parameter.
        For Python or Output nodes, `targets` are don't-care.

    Args:
        targets ([ScopedValue]): A list of instances of ScopedValue as new targets.
    """
    self._targets = targets
    assign_backed_types = (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
                           NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
                           NodeType.MathOps)
    if self._node_type not in assign_backed_types:
        return
    self._sync_assign_targets_to_ast()
|
|
751
|
+
|
|
752
|
+
def get_func_name(self) -> ScopedValue:
    """
    Getter of `_func_name`. See detail in docstring of Node class for meaning of func.

    Returns:
        The stored `_func_name`, an instance of ScopedValue.
    """
    func_name = self._func_name
    return func_name
|
|
760
|
+
|
|
761
|
+
def set_func_name(self, func_name: ScopedValue):
    """
    Setter of `_func_name`. See detail in docstring of Node class for meaning of func.

    Note:
        For CallCell and CallPrimitive nodes the new name is also synchronized to the ast node. For
        other node types only `_func_name` is updated.

    Args:
        func_name (ScopedValue): An instance of ScopedValue as new func.
    """
    self._func_name = func_name
    if self._node_type not in (NodeType.CallCell, NodeType.CallPrimitive):
        return
    self._sync_assign_func_name_to_ast()
|
|
774
|
+
|
|
775
|
+
def get_name(self) -> str:
    """
    Getter of `_name`.

    Returns:
        The stored name of node.
    """
    node_name = self._name
    return node_name
|
|
783
|
+
|
|
784
|
+
def set_name(self, name: str):
    """
    Setter of `_name`.

    Args:
        name (str): New name of node. Stored as given; no validation is performed here.
    """
    self._name = name
|
|
792
|
+
|
|
793
|
+
def get_node_type(self) -> NodeType:
    """
    Get the node_type of current node.

    Returns:
        The stored `_node_type`.
    """
    node_type = self._node_type
    return node_type
|
|
801
|
+
|
|
802
|
+
def get_instance_type(self) -> type:
    """
    Get the instance_type of current node.

    - When the instance is a LocalPrim, its `prim_obj` is returned.
    - When the instance is a plain python function, the function itself is returned.
    - Otherwise the type of the instance is returned: the type of cell-op for CallCell, of
      primitive-op for CallPrimitive, of network-cell for Tree, and NoneType when the instance is None
      (Python, Input, Output or CallMethod nodes).

    Returns:
        A type (or a function / prim_obj in the two special cases above).
    """
    instance = self._instance
    if isinstance(instance, LocalPrim):
        return instance.prim_obj
    return instance if inspect.isfunction(instance) else type(instance)
|
|
819
|
+
|
|
820
|
+
def get_instance(self):
    """
    Get the instance of current node.

    - For CallCell nodes, instance is an instance of Cell.
    - For CallPrimitive nodes, instance is an instance of primitive.
    - For Tree nodes, instance is an instance of network-cell.
    - For Python, Input, Output or CallMethod nodes, instance should be None.

    Returns:
        The stored `_instance` object.
    """
    instance = self._instance
    return instance
|
|
833
|
+
|
|
834
|
+
def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None):
    """
    Set argument by another Node.
    Note that when _normalized_args is updated, corresponding ast node would be updated also.

    Args:
        arg_idx (int): Indicate which input being modified. Must satisfy 0 <= arg_idx < number of args.
        node (Node): Node whose target is used as the new input.
        out_idx ([int, optional]): Indicate which output (target) of `node` is used as new argument.
            Default is None, which means use the single output of `node`.

    Raises:
        ValueError: If `out_idx` is None while `node` does not have exactly one output.

    Note:
        Type and range violations of `node`, `arg_idx` and `out_idx` are reported by Validator; the
        exception types are defined there.
    """
    Validator.check_value_type("node", node, [Node], "Node")
    Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
    targets = node.get_targets()
    if out_idx is None:
        if len(targets) != 1:
            raise ValueError("node should has one output when out_idx is not provided")
        out_idx = 0
    # Report a bad output index under its own parameter name, not "arg_idx".
    Validator.check_int_range(out_idx, 0, len(targets), Validator.INC_LEFT, "out_idx")
    self._normalized_args[self._normalized_args_keys[arg_idx]] = targets[out_idx]
    self._sync_arg()
|
|
859
|
+
|
|
860
|
+
def set_arg(self, arg: Union[ScopedValue, str], index: int) -> (ScopedValue, ScopedValue):
    """
    Set argument of `node`.
    Note that when _normalized_args is updated, corresponding ast node would be updated also.

    Args:
        arg (Union[ScopedValue, str]): New argument to been set. A str is converted to a NamingValue.
        index (int): Indicate which input being modified. Must satisfy 0 <= index < number of args.

    Returns:
        A tuple (new_arg, old_arg): the ScopedValue now stored and the value it replaced.

    Note:
        Type and range violations are reported by Validator; the exception types are defined there.
    """
    Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
    Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
    new_arg = ScopedValue.create_naming_value(arg) if isinstance(arg, str) else arg
    previous_arg = self._normalized_args.get(self._normalized_args_keys[index])
    self._normalized_args[self._normalized_args_keys[index]] = new_arg
    self._sync_arg()
    return new_arg, previous_arg
|
|
880
|
+
|
|
881
|
+
def set_args(self, args: [ScopedValue]):
    """
    Set arguments of `node`.
    Note that when _normalized_args is updated, corresponding ast node would be updated also.

    The i-th element of `args` replaces the i-th positional argument. `args` may be shorter than the
    current positional arguments (only a prefix is replaced) but never longer.

    Args:
        args (list[ScopedValue]): New arguments to been set. At most as many as the node's current
            positional arguments.

    Raises:
        TypeError: Element of new argument is not an instance of ScopedValue.

    Note:
        A length outside the allowed range is reported by Validator; the exception type is defined there.
    """
    # The upper bound is inclusive: replacing every positional argument at once (and passing an empty
    # list to a node without arguments) is valid.
    Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_BOTH, "Length of args")
    Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
    for arg_index, arg in enumerate(args):
        if not isinstance(arg, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(arg))
        self._normalized_args[self._normalized_args_keys[arg_index]] = arg
    self._sync_arg()
|
|
899
|
+
|
|
900
|
+
def set_kwargs(self, kwargs: {str: ScopedValue}):
    """
    Set keywords arguments of 'node'.
    Note that when _normalized_args is updated, corresponding ast node would be updated also.

    Every key of `kwargs` must already be a key of the normalized arguments; its value is replaced.
    `kwargs` may hold fewer entries than the node's current keyword arguments but never more.

    Args:
        kwargs (dict{str: ScopedValue}): New arguments to been set.

    Raises:
        RuntimeError: A key of `kwargs` is not an existing argument key of the node.
        TypeError: Value of new argument is not an instance of ScopedValue.

    Note:
        A length outside the allowed range is reported by Validator; the exception type is defined there.
    """
    # The upper bound is inclusive: replacing every keyword argument at once (and passing an empty
    # dict to a node without keyword arguments) is valid.
    Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_BOTH, "Length of kwargs")
    Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
    for key, arg in kwargs.items():
        if key not in self._normalized_args or key not in self._normalized_args_keys:
            raise RuntimeError("Input key is not exist, ", key)
        if not isinstance(arg, ScopedValue):
            raise TypeError("arg should be ScopedValue, got: ", type(arg))
        self._normalized_args[key] = arg
    self._sync_arg()
|
|
921
|
+
|
|
922
|
+
def set_kwarg(self, key: str, arg: ScopedValue):
    """
    Set keyword argument of 'node'.
    Note that when _normalized_args is updated, corresponding ast node would be updated also.

    Args:
        key (str): A str represents key of new argument. Must be one of the keys stored after the
            positional-argument keys.
        arg (ScopedValue): An instance of ScopedValue represents argument.

    Raises:
        RuntimeError: If 'key' is not in original kwargs' keys.
    """
    keyword_keys = self._normalized_args_keys[self._args_num:]
    is_known_kwarg = key in keyword_keys and key in self._normalized_args.keys()
    if not is_known_kwarg:
        raise RuntimeError("Input key is not exist, ", key)
    self._normalized_args[key] = arg
    self._sync_arg()
|
|
938
|
+
|
|
939
|
+
def get_args(self):
    """
    Get the arguments of current node.

    The values are looked up for the first `_args_num` keys of the normalized arguments.

    - For CallCell, CallPrimitive or Tree nodes, arguments correspond to args of ast.Call, i.e. the
      arguments used to invoke cell-op's forward method or primitive-op's `call()` method.
    - For Input nodes, arguments represent the default-value of the function argument.
    - For Output nodes, arguments represent return values.
    - For Python nodes, arguments are don't-care.

    Returns:
        A list of instances of ScopedValue.
    """
    return [self._normalized_args.get(self._normalized_args_keys[position])
            for position in range(self._args_num)]
|
|
956
|
+
|
|
957
|
+
def get_kwargs(self):
    """
    Get the keyword arguments of current node.

    The keys are the `_kwargs_num` normalized-argument keys that follow the first `_args_num` keys.

    - For CallCell, CallPrimitive or Tree nodes, keyword arguments correspond to kwargs of ast.Call,
      i.e. the arguments used to invoke cell-op's forward method or primitive-op's `call()` method.
    - For Python, Input or Output nodes, keyword arguments are don't-care.

    Returns:
        A dict of str to instance of ScopedValue.
    """
    first = self._args_num
    return {self._normalized_args_keys[position]: self._normalized_args.get(self._normalized_args_keys[position])
            for position in range(first, first + self._kwargs_num)}
|
|
974
|
+
|
|
975
|
+
def get_normalized_args(self) -> {str: ScopedValue}:
    """
    Get the normalized keyword arguments of current node.
    Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter
    name as key of arguments.

    Returns:
        A new dict of str to instance of ScopedValue, ordered like `_normalized_args_keys`.
    """
    return {key: self._normalized_args.get(key) for key in self._normalized_args_keys}
|
|
988
|
+
|
|
989
|
+
def set_normalized_args(self, args: {str, ScopedValue}):
    """
    Set the normalized keyword arguments of current node.
    Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter
    name as key of arguments.

    The entries of `args` are written into the stored normalized arguments, then the arguments are
    synchronized via `_sync_arg()`.

    Args:
        args ({str, ScopedValue}): A dict of str to instance of ScopedValue represents new normalized_args.

    Raises:
        RuntimeError: If the number of entries in `args` differs from the number of normalized keys.
    """
    given, expected = len(args.values()), len(self._normalized_args_keys)
    if given != expected:
        raise RuntimeError("Length of args.values() should be equal to length of _normalized_args_keys, ",
                           given, " vs ", expected)
    self._normalized_args.update(args.items())
    self._sync_arg()
|
|
1004
|
+
|
|
1005
|
+
def set_attribute(self, key: str, value):
    """
    Set one attribute of current node, adding it or overwriting an existing entry.

    Args:
        key (str): Key of new attribute.
        value (object): Value of new attribute.
    """
    self._attribute[key] = value
|
|
1014
|
+
|
|
1015
|
+
def set_attributes(self, attributes):
    """
    Replace all attributes of current node.

    The given object is stored as is (not copied).

    Args:
        attributes (dict): A dict represents new attributes.
    """
    self._attribute = attributes
|
|
1023
|
+
|
|
1024
|
+
def get_attributes(self):
    """
    Get all attributes of current node.

    Returns:
        The stored attribute dict itself (str to object), not a copy.
    """
    attributes = self._attribute
    return attributes
|
|
1032
|
+
|
|
1033
|
+
def get_attribute(self, key: str):
    """
    Get attribute of current node by key.

    Args:
        key (str): A str represents key of attribute you want to get.

    Returns:
        The attribute stored under `key`, or None when there is no such key.
    """
    attribute = self._attribute.get(key)
    return attribute
|
|
1044
|
+
|
|
1045
|
+
def get_arg_providers(self) -> dict:
    """
    Getter of _arg_providers.

    Return:
        dict, key is the int index of an argument, and value is a tuple holding the node and the index
        of that node's target which provides the argument.
    """
    providers = self._arg_providers
    return providers
|
|
1054
|
+
|
|
1055
|
+
def set_arg_providers(self, index: int, provider: tuple):
    """
    Setter of _arg_providers for a single argument.

    Args:
        index (int): Indicating provider of which argument need to be set.
        provider (tuple): A tuple includes the node and the index of node's targets who provides the
            argument.
    """
    self._arg_providers[index] = provider
|
|
1064
|
+
|
|
1065
|
+
def get_target_users(self, index=-1) -> Union[dict, list]:
    """
    Getter of _target_users.

    Args:
        index (int): Indicating users of which target need to be got. Default: -1, means all targets's
            users will be returned.

    Return:
        Union[dict, list]. When index is -1, the whole dict is returned: key is index of targets, value
        is the list of users of the corresponding target. Otherwise the list of users of the specified
        target is returned; when the target has no entry yet, an empty list is stored for it and
        returned. Each element of a user list is a tuple of the user node and the index of that node's
        argument which uses the target.
    """
    if index == -1:
        return self._target_users
    # setdefault registers an empty user list on first access and hands back the stored list.
    return self._target_users.setdefault(index, [])
|
|
1084
|
+
|
|
1085
|
+
def append_target_users(self, index: int, provider: tuple):
    """
    Append one user to the user list of a target.

    A user list is created for the target when it has none yet.

    Args:
        index (int): Indicating users of which target need to be append.
        provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
    """
    self._target_users.setdefault(index, []).append(provider)
|
|
1097
|
+
|
|
1098
|
+
def update_ast_node(self) -> ast.AST:
    """
    Rebuild node's ast_node from current targets, func_name, args and kwargs.

    Returns:
        The newly created ast node, which has also been stored through `set_ast`.
    """
    targets = self.get_targets()
    func_name = self.get_func_name()
    rebuilt_assign = AstModifier.create_call_assign(targets, func_name, self.get_args(), self.get_kwargs())
    self.set_ast(rebuilt_assign)
    return rebuilt_assign
|
|
1104
|
+
|
|
1105
|
+
def get_source_code(self) -> str:
    """Unparse the ast of node into source code, with surrounding whitespace stripped."""
    unparsed = astunparse.unparse(self._ast_node)
    return unparsed.strip()
|
|
1108
|
+
|
|
1109
|
+
def append_kwarg(self, kwarg: Dict[str, ScopedValue]):
    """
    Append new keyword args to node.

    Every entry is registered in the normalized arguments, counted as a keyword argument, and added
    as an ast.keyword to the ast.Call held by the node's assign statement.

    Args:
        kwarg (Dict[str, ScopedValue]): The new keyword arg(s).

    Raises:
        TypeError: If the node type is neither Tree nor CallFunction.
    """
    if self.get_node_type() not in [NodeType.Tree, NodeType.CallFunction]:
        raise TypeError(f"For append_new_kwarg, the type of node can only be one of [Tree, CallFunction], "
                        f"but got {self.get_node_type()}")
    Validator.check_element_type_of_dict("kwarg", kwarg, [str], [ScopedValue], "append_new_kwarg")
    for keyword_name, scoped_value in kwarg.items():
        # Register the keyword in the node's own bookkeeping.
        self._normalized_args[keyword_name] = scoped_value
        self._normalized_args_keys.append(keyword_name)
        self._kwargs_num += 1
        # Mirror the keyword in the call expression of the assign statement.
        call_node = self._ast_node.value
        keyword_node = ast.keyword(arg=keyword_name, value=AstModifier.get_ast_by_value(scoped_value, None))
        call_node.keywords.append(keyword_node)
|
|
1131
|
+
|
|
1132
|
+
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
    """
    Merge positional args and keyword args into one name-to-value dict.

    When args or kwargs are given and `type(self._instance)` has a `construct` attribute, the arg names are
    taken from the signature of that `construct` and the mapping is delegated to `Node._map_args_names`.
    Otherwise positional args get generated names `arg_<n>` which avoid any name already used by kwargs or
    by an earlier positional arg, and kwargs keep their own names.

    Besides building the returned dict, the fallback path appends every key to
    `self._normalized_args_keys`; the construct path passes that list to `Node._map_args_names`.

    Args:
        args (list[ScopedValue]): Positional args. A falsy value is treated as an empty list.
        kwargs (dict{str: ScopedValue}): Keyword args. A falsy value is treated as an empty dict.

    Returns:
        A dict, the normalized args.
    """
    positional = args if args else []
    keywords = kwargs if kwargs else {}
    merged: dict = {}

    if not ((positional or keywords) and self._instance and hasattr(type(self._instance), "construct")):
        logger.debug("fail to get arg name from op, using arg_xx for args' name")
        suffix = 0
        for value in positional:
            candidate = f"arg_{suffix}"
            # Skip over names taken by a keyword or by an earlier positional arg.
            while candidate in keywords or candidate in merged:
                suffix += 1
                candidate = f"arg_{suffix}"
            merged[candidate] = value
            self._normalized_args_keys.append(candidate)
        for keyword_name, value in keywords.items():
            merged[keyword_name] = value
            self._normalized_args_keys.append(keyword_name)
        return merged

    construct_parameters = inspect.signature(type(self._instance).construct).parameters
    arg_names = Node._get_construct_arg_names(construct_parameters)
    Node._map_args_names(arg_names, positional, keywords, self._normalized_args_keys, merged)
    return merged
|
|
1171
|
+
|
|
1172
|
+
# Synchronize rewrite node args to ast node
|
|
1173
|
+
def _sync_assign_func_name_to_ast(self):
    """
    Sync the callee of the ast.Call held by the ast.Assign of this node from self._func_name.

    When self._func_name has no scope the callee becomes a bare name (`value`), otherwise it becomes an
    attribute access (`scope.value`). An existing ast.Name / ast.Attribute callee is updated in place; any
    other callee ast is replaced by a newly created one.

    Nothing is done when the node holds no ast node, or when the type of self._func_name is
    ValueType.UnsupportedValue.

    Raises:
        TypeError: If self._ast_node is not an ast.Assign.
        TypeError: If the value of self._ast_node is not an ast.Call.
    """
    assign_ast = self._ast_node
    if assign_ast is None:
        return
    if not isinstance(assign_ast, ast.Assign):
        raise TypeError(f"assign_ast should be ast.Assign, got: {type(assign_ast)}")
    call_ast = assign_ast.value
    if not isinstance(call_ast, ast.Call):
        raise TypeError(f"call_ast should be ast.Call, got: {type(call_ast)}")
    if self._func_name.type == ValueType.UnsupportedValue:
        return
    func_ast = call_ast.func
    scope = self._func_name.scope
    func_value = self._func_name.value
    if not scope:
        if isinstance(func_ast, ast.Name):
            func_ast.id = func_value
        else:
            # The callee of a call is read, not assigned, so it has to carry a Load context.
            call_ast.func = ast.Name(func_value, ast.Load())
    elif isinstance(func_ast, ast.Attribute):
        if isinstance(func_ast.value, ast.Name):
            func_ast.value.id = scope
        else:
            func_ast.value = ast.Name(scope, ast.Load())
        func_ast.attr = func_value
    else:
        # Same as above: both the scope name and the attribute access are loads.
        call_ast.func = ast.Attribute(ast.Name(scope, ast.Load()), func_value, ast.Load())
    # Newly created nodes have no position info yet.
    ast.fix_missing_locations(assign_ast)
|
|
1202
|
+
|
|
1203
|
+
def _sync_assign_targets_to_ast(self):
    """
    Sync the target elements of the ast.Assign of this node from self._targets.

    The target elements are obtained by `AstConverter.get_ast_target_elems` from the first target of the
    ast.Assign, and each of them is replaced by the result of `AstModifier.get_ast_by_value` for the target
    at the same position. Nothing is done when the node holds no ast node.

    Raises:
        TypeError: If self._ast_node is not an ast.Assign.
        ValueError: If the number of self._targets differs from the number of target elements in ast.
    """
    node_ast = self._ast_node
    if node_ast is None:
        return
    if not isinstance(node_ast, ast.Assign):
        message = f"assign_ast should be ast.Assign, but got: {type(node_ast)}"
        raise TypeError(error_str(message, father_node=node_ast))
    # update targets
    ast_elems = AstConverter.get_ast_target_elems(node_ast.targets[0])
    actual_num = len(self._targets)
    expected_num = len(ast_elems)
    if actual_num != expected_num:
        message = f"The number of targets should be {expected_num}, but got {actual_num}"
        raise ValueError(error_str(message, father_node=node_ast))
    for position, (target, old_elem) in enumerate(zip(self._targets, ast_elems)):
        ast_elems[position] = AstModifier.get_ast_by_value(target, old_elem)
|
|
1218
|
+
|
|
1219
|
+
def _sync_call_args_to_ast(self):
    """
    Sync args and keywords of the ast.Call of this node from self._normalized_args.

    The first `self._args_num` normalized args are written to the positional args of the call by position;
    the following `self._kwargs_num` normalized args are written to the keyword of the call with the same
    name. Nothing is done when the node holds no ast node, or when the value of the ast.Assign is an
    ast.Attribute and the node type is CellContainer or CallCell.

    Raises:
        TypeError: If self._ast_node is not an ast.Assign.
        TypeError: If the value of self._ast_node is not an ast.Call (except for the case described above).
        ValueError: If the number of normalized args differs from the number of args plus keywords in ast.
    """
    assign_ast = self._ast_node
    if assign_ast is None:
        return
    if not isinstance(assign_ast, ast.Assign):
        raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node should be "
                        f"ast.Assign, but got: {type(assign_ast)}")
    call_ast = assign_ast.value
    if not isinstance(call_ast, ast.Call):
        if isinstance(call_ast, ast.Attribute) and self._node_type in (NodeType.CellContainer,
                                                                       NodeType.CallCell):
            # CellContainers in control flow may be flatten to ast.Attribute: blocks_var = self.blocks
            # In this case, no args exist in node, so we don't need to sync.
            # CellContainers may be type of CallCell when share one implementation
            return
        raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node.value should "
                        f"be ast.Call, but got: {type(call_ast)}")
    call_keywords = call_ast.keywords
    call_positional = call_ast.args
    if len(self._normalized_args_keys) != (len(call_keywords) + len(call_positional)):
        raise ValueError("ast keywords plus args len is not equal to self._normalized_args value")

    for position in range(self._args_num):
        arg_name = self._normalized_args_keys[position]
        call_positional[position] = AstModifier.get_ast_by_value(self._normalized_args.get(arg_name),
                                                                 call_positional[position])

    # the order of kwargs may not the same as that in keywords_ast
    index_of_keyword = {keyword_ast.arg: index for index, keyword_ast in enumerate(call_keywords)}
    for offset in range(self._kwargs_num):
        keyword_name = self._normalized_args_keys[offset + self._args_num]
        keyword_ast = call_keywords[index_of_keyword.get(keyword_name)]
        keyword_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(keyword_name),
                                                         keyword_ast.value)
|
|
1255
|
+
|
|
1256
|
+
def _sync_call_method_args_to_ast(self):
    """
    Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
    For node with type of CallMethod, the value of ast.Assign is one of:
    | func_name      | data_type   | value of ast.Assign     |
    |:---------------|:------------|:------------------------|
    | 'pass_through' | constants   | ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str |
    | 'pass_through' | variables   | ast.Name, ast.Attribute |
    | 'tuple'        | tuple       | ast.Tuple               |
    | 'list'         | list        | ast.List                |
    | 'dict'         | dict        | ast.Dict                |

    For 'pass_through' the whole value of the ast.Assign is replaced from the first normalized arg. For
    'tuple', 'list' and 'dict' every element (`values` of an ast.Dict, `elts` otherwise) is replaced from the
    normalized arg at the same position. Nothing is done when the node holds no ast node.

    Raises:
        TypeError: If self._ast_node is not an ast.Assign.
        TypeError: If the value of self._func_name is none of 'pass_through', 'tuple', 'list', 'dict'.
        ValueError: If the number of normalized args differs from the number of elements in ast.
    """
    assign_ast = self._ast_node
    if assign_ast is None:
        return
    if not isinstance(assign_ast, ast.Assign):
        # The type is formatted into the message; passing it as a second exception arg would make
        # str(error) the repr of a tuple.
        raise TypeError(f"For node '{self.get_name()}', assign_ast should be ast.Assign, "
                        f"got: {type(assign_ast)}")
    assign_value = assign_ast.value
    method_name = self._func_name.value
    if method_name == "pass_through":
        # update constants/variables
        assign_ast.value = \
            AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), assign_value)
    elif method_name in ("tuple", "list", "dict"):
        # update tuple/list/dict
        ast_elts = assign_value.values if isinstance(assign_value, ast.Dict) else assign_value.elts
        if len(self._normalized_args_keys) != len(ast_elts):
            raise ValueError(f"For node '{self.get_name()}', size of self._normalized_args_keys"
                             f"({len(self._normalized_args_keys)}) should be equal to size of elements of "
                             f"ast_elts({len(ast_elts)})")
        for index, elt in enumerate(ast_elts):
            scoped_value: ScopedValue = self._normalized_args.get(self._normalized_args_keys[index])
            ast_elts[index] = AstModifier.get_ast_by_value(scoped_value, elt)
    else:
        raise TypeError(f"For node '{self.get_name()}', only support (pass_through, tuple, list or dict method) "
                        f"as call_method, but got {method_name}")
|
|
1291
|
+
|
|
1292
|
+
def _sync_return_node_to_ast(self):
    """
    Sync the first normalized arg of this node to the value of its ast.Return.

    The value of the ast.Return is replaced by the result of `AstModifier.get_ast_by_value` for the first
    entry of self._normalized_args. Nothing is done when the node holds no ast node.

    Raises:
        TypeError: If self._ast_node is not an ast.Return.
    """
    return_ast = self._ast_node
    if return_ast is None:
        return
    if not isinstance(return_ast, ast.Return):
        raise TypeError(f"For node '{self.get_name()}', return_ast should be ast.Return, got: {type(return_ast)}")
    old_value_ast = return_ast.value
    first_key = self._normalized_args_keys[0]
    return_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(first_key), old_value_ast)
|
|
1307
|
+
|
|
1308
|
+
def _sync_mathops_node_args_to_ast(self):
    """
    Sync values from self._normalized_args to the ast node for mathematical operations.

    The value of the ast.Assign decides where the normalized args are written, by position:

    - ast.BinOp: arg 0 to `left`, arg 1 to `right`.
    - ast.UnaryOp: arg 0 to `operand`.
    - ast.BoolOp: the first `self._args_num` args to `values`.
    - ast.Compare: arg 0 to `left`, args 1 to `self._args_num - 1` to `comparators`.

    Nothing is done when the node holds no ast node.

    Raises:
        TypeError: If self._ast_node is not an ast.Assign.
        TypeError: If the value of self._ast_node is none of ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare.
    """
    if self._ast_node is None:
        return
    if not isinstance(self._ast_node, ast.Assign):
        raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
    mathops_node = self._ast_node.value
    arg_keys = self._normalized_args_keys
    arg_values = self._normalized_args
    if isinstance(mathops_node, ast.BinOp):
        mathops_node.left = AstModifier.get_ast_by_value(arg_values.get(arg_keys[0]), mathops_node.left)
        mathops_node.right = AstModifier.get_ast_by_value(arg_values.get(arg_keys[1]), mathops_node.right)
    elif isinstance(mathops_node, ast.UnaryOp):
        mathops_node.operand = AstModifier.get_ast_by_value(arg_values.get(arg_keys[0]), mathops_node.operand)
    elif isinstance(mathops_node, ast.BoolOp):
        values = mathops_node.values
        for arg_index in range(self._args_num):
            values[arg_index] = AstModifier.get_ast_by_value(arg_values.get(arg_keys[arg_index]),
                                                             values[arg_index])
    elif isinstance(mathops_node, ast.Compare):
        mathops_node.left = AstModifier.get_ast_by_value(arg_values.get(arg_keys[0]), mathops_node.left)
        comparators = mathops_node.comparators
        # Arg 0 is the left operand, so comparator i pairs with arg i + 1.
        for arg_index in range(1, self._args_num):
            comparators[arg_index - 1] = AstModifier.get_ast_by_value(arg_values.get(arg_keys[arg_index]),
                                                                      comparators[arg_index - 1])
    else:
        # The type is formatted into the message; passing it as a second exception arg would make
        # str(error) the repr of a tuple.
        raise TypeError(f"The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
                        f"ast.BoolOp, ast.Compare), but got {type(mathops_node)}")
|
|
1344
|
+
|
|
1345
|
+
def _sync_control_flow_args_to_ast(self):
    """
    Sync values from self._normalized_args to the ast node of control flow.

    A control flow node carries at most one normalized arg. It is written to `test` of an ast.If, ast.IfExp
    or ast.While, or to `iter` of an ast.For. Nothing is done when the node holds no ast node or has no
    normalized arg.

    Raises:
        ValueError: If the node has more than one normalized arg.
        ValueError: If self._ast_node is none of ast.If, ast.IfExp, ast.While, ast.For.
    """
    if self._ast_node is None:
        return
    normalized_args_num = len(self._normalized_args_keys)
    if normalized_args_num == 0:
        return
    if normalized_args_num > 1:
        # Exactly one arg is accepted, so the limit reported is "no more than 1".
        raise ValueError(f"self._normalized_args_keys should have no more than 1 element, "
                         f"but got {normalized_args_num}")
    arg_value = self._normalized_args.get(self._normalized_args_keys[0])
    if isinstance(self._ast_node, (ast.If, ast.IfExp, ast.While)):
        self._ast_node.test = AstModifier.get_ast_by_value(arg_value, self._ast_node.test)
    elif isinstance(self._ast_node, ast.For):
        self._ast_node.iter = AstModifier.get_ast_by_value(arg_value, self._ast_node.iter)
    else:
        raise ValueError(f"For Control Flow, ast_node should be one of [ast.If, ast.IfExp, "
                         f"ast.While, ast.For], but got {type(self._ast_node)}")
|
|
1364
|
+
|
|
1365
|
+
def _sync_arg(self):
    """
    Push self._normalized_args to the ast node, using the sync method that matches the type of this node.

    CallCell, CallPrimitive, Tree, CellContainer and CallFunction nodes sync as calls; Output, CallMethod,
    MathOps and ControlFlow nodes each have their own sync method. Nothing is done for any other type.
    """
    node_type = self._node_type
    if node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,
                     NodeType.CellContainer, NodeType.CallFunction):
        self._sync_call_args_to_ast()
        return
    if node_type == NodeType.Output:
        self._sync_return_node_to_ast()
        return
    if node_type == NodeType.CallMethod:
        self._sync_call_method_args_to_ast()
        return
    if node_type == NodeType.MathOps:
        self._sync_mathops_node_args_to_ast()
        return
    if node_type == NodeType.ControlFlow:
        self._sync_control_flow_args_to_ast()
|
|
1378
|
+
|
|
1379
|
+
|
|
1380
|
+
# Child classes
|
|
1381
|
+
class TreeNode(Node):
    """Node of type Tree, which additionally keeps the handler of a sub SymbolTree in `symbol_tree`."""

    def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
                 args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
        """
        Constructor of TreeNode. Prefer the class method `create_tree_node`, which also normalizes `targets`,
        over invoking this constructor directly.

        Args:
            tree: Handler of the sub-symbol-tree, stored as `self.symbol_tree`.
            ast_node (ast.AST): The corresponding node in ast.
            targets (list[ScopedValue]): Targets of the node. See detail in docstring of Node class.
            func (Union[ScopedValue, str]): The callee. A str is converted by
                `ScopedValue.create_naming_value`.
            args (list[ScopedValue]): Positional args. See detail in docstring of Node class.
            kwargs (dict{str: ScopedValue}): Keyword args. See detail in docstring of Node class.
            name (str): Name of the node.
            instance: Object in network corresponding to this node.
        """
        scoped_func = ScopedValue.create_naming_value(func) if isinstance(func, str) else func
        super().__init__(NodeType.Tree, ast_node, targets, scoped_func, args, kwargs, name, instance)
        self.symbol_tree = tree

    @classmethod
    def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
                         func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
                         name: str = "", instance=None):
        """
        Class method of TreeNode. Build a node whose type is Tree, after passing `targets` through
        `Node._handle_targets` and converting a str `func_name` by `ScopedValue.create_naming_value`.

        Args:
            tree: Handler of the sub-symbol-tree.
            ast_node (ast.AST): The corresponding node in ast.
            targets (Union[ScopedValue, str]): Targets of the node, normalized by `Node._handle_targets`.
            func_name (Union[ScopedValue, str]): The callee.
            args (list[ScopedValue]): Positional args. See detail in docstring of Node class.
            kwargs (dict{str: ScopedValue}): Keyword args. See detail in docstring of Node class.
            name (str): Name of the node. Default: ``""``.
            instance: Object in network corresponding to this node. Default: ``None``.

        Returns:
            An instance of `cls`.
        """
        handled_targets = Node._handle_targets(targets)
        scoped_func = ScopedValue.create_naming_value(func_name) if isinstance(func_name, str) else func_name
        return cls(tree, ast_node, handled_targets, scoped_func, args, kwargs, name, instance)
|