mindspore 2.4.0__cp311-cp311-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,852 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
|
+
from typing import Union, List, Dict
|
|
17
|
+
import types
|
|
18
|
+
import os
|
|
19
|
+
import ast
|
|
20
|
+
import sys
|
|
21
|
+
import inspect
|
|
22
|
+
import builtins
|
|
23
|
+
from textwrap import dedent
|
|
24
|
+
|
|
25
|
+
from mindspore import log as logger
|
|
26
|
+
from mindspore.nn import Cell, SequentialCell, CellList
|
|
27
|
+
from mindspore.ops.primitive import Primitive
|
|
28
|
+
import mindspore.ops.functional as F
|
|
29
|
+
from . import Parser, ParserRegister, reg_parser
|
|
30
|
+
from ..symbol_tree import SymbolTree
|
|
31
|
+
from ..node import Node, TreeNode, NodeManager, CallFunction, CellContainer, ControlFlow, LocalPrim
|
|
32
|
+
from ..api.scoped_value import ScopedValue
|
|
33
|
+
from ..ast_helpers import AstFlattener, AstConverter, AstFinder
|
|
34
|
+
from ..common.error_log import error_str
|
|
35
|
+
from ..common.namespace import is_subtree, is_ms_function, is_third_party
|
|
36
|
+
from ..common.namer import FunctionNamer
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
if sys.version_info >= (3, 9):
|
|
40
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
41
|
+
else:
|
|
42
|
+
import astunparse
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class AssignParser(Parser):
|
|
46
|
+
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
47
|
+
|
|
48
|
+
# Types for creating Cell Container node
|
|
49
|
+
types_for_cell_container = [SequentialCell,]
|
|
50
|
+
# Whether mindspore built-in functions are skipped instead of being parsed
|
|
51
|
+
_skip_ms_function = False
|
|
52
|
+
# Functions in black list will not be parsed
|
|
53
|
+
_function_parse_black_list = [F.arange]
|
|
54
|
+
# Share one implementation for the same instances
|
|
55
|
+
_share_one_implementation = False
|
|
56
|
+
# Implementation caches of sub SymbolTrees, CallFunction nodes and CellContainer nodes
|
|
57
|
+
# Keys are ids of the instance object
|
|
58
|
+
_cached_trees: Dict[int, SymbolTree] = {}
|
|
59
|
+
_cached_functions: Dict[int, Node] = {}
|
|
60
|
+
_cached_cell_containers: Dict[int, Node] = {}
|
|
61
|
+
|
|
62
|
+
def __init__(self):
    """Create a parser with no active parsing context and an empty environment stack."""
    super().__init__()
    # Context of the ast.Assign being parsed; every field starts unset.
    self.stree: SymbolTree = None
    self.ast_assign: ast.Assign = None
    self.node_manager: NodeManager = None
    self.targets: List[ScopedValue] = None
    self.args: List[ScopedValue] = None
    self.kwargs: Dict[str, ScopedValue] = None
    # Stack of contexts pushed by store_env() and popped by restore_env().
    self._variables_cache = []
|
|
71
|
+
|
|
72
|
+
@staticmethod
|
|
73
|
+
def _get_func_name(ast_call: ast.Call) -> str:
|
|
74
|
+
"""
|
|
75
|
+
Get the func name from ast.Call.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
ast_call (ast.Call): Input ast.Call node.
|
|
79
|
+
|
|
80
|
+
Returns:
|
|
81
|
+
Func name.
|
|
82
|
+
"""
|
|
83
|
+
func = ast_call.func
|
|
84
|
+
if isinstance(func, ast.Name):
|
|
85
|
+
return func.id
|
|
86
|
+
if isinstance(func, ast.Attribute):
|
|
87
|
+
return func.attr
|
|
88
|
+
func_full_name = astunparse.unparse(func).strip()
|
|
89
|
+
if func_full_name.count('.') > 0:
|
|
90
|
+
return func_full_name.split('.')[-1]
|
|
91
|
+
return func_full_name
|
|
92
|
+
|
|
93
|
+
@staticmethod
|
|
94
|
+
def _get_func_scope(ast_call: ast.Call) -> str:
|
|
95
|
+
"""
|
|
96
|
+
Get the func scope from ast.Call.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
ast_call (ast.Call): Input ast.Call node.
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Func scope.
|
|
103
|
+
"""
|
|
104
|
+
func = ast_call.func
|
|
105
|
+
if isinstance(func, ast.Name):
|
|
106
|
+
return ""
|
|
107
|
+
func_full_name = astunparse.unparse(func).strip()
|
|
108
|
+
if func_full_name.count('.') > 0:
|
|
109
|
+
return func_full_name.rsplit('.', 1)[0]
|
|
110
|
+
return ""
|
|
111
|
+
|
|
112
|
+
@staticmethod
|
|
113
|
+
def _create_targets(ast_target: ast.AST) -> List[ScopedValue]:
    """
    Get targets from ast node.

    Args:
        ast_target (ast.AST): Target expression of an assign statement.

    Returns:
        A list with one ScopedValue per element returned by AstConverter.get_ast_target_elems.
    """
    return [AstConverter.create_scopedvalue(elem) for elem in AstConverter.get_ast_target_elems(ast_target)]
|
|
118
|
+
|
|
119
|
+
@staticmethod
|
|
120
|
+
def _create_kwargs(keywords: [ast.keyword]) -> Dict[str, ScopedValue]:
    """
    Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node.

    Args:
        keywords ([ast.keyword]): Keywords of ast.Call node.

    Returns:
        A dict mapping each keyword name to the ScopedValue created from its value.
    """
    return {keyword.arg: AstConverter.create_scopedvalue(keyword.value) for keyword in keywords}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@staticmethod
|
|
137
|
+
def _get_inst_and_name(ast_node: ast.Attribute, stree: SymbolTree):
|
|
138
|
+
"""
|
|
139
|
+
Try to get instance object of ast_node from ast.Attribute.
|
|
140
|
+
"""
|
|
141
|
+
if not isinstance(ast_node, ast.Attribute):
|
|
142
|
+
return None, ""
|
|
143
|
+
scope_name = astunparse.unparse(ast_node).strip()
|
|
144
|
+
scope, name = scope_name.split('.', 1)
|
|
145
|
+
if scope != 'self':
|
|
146
|
+
return None, scope_name
|
|
147
|
+
if not hasattr(stree.get_origin_network(), name):
|
|
148
|
+
return None, scope_name
|
|
149
|
+
return getattr(stree.get_origin_network(), name), scope_name
|
|
150
|
+
|
|
151
|
+
@staticmethod
|
|
152
|
+
def _list_of_cells(cell_list: list):
|
|
153
|
+
"""Check if elements in the list are all cells."""
|
|
154
|
+
for item in cell_list:
|
|
155
|
+
if not isinstance(item, Cell):
|
|
156
|
+
return False
|
|
157
|
+
return True
|
|
158
|
+
|
|
159
|
+
@staticmethod
|
|
160
|
+
def _get_path_of_node_manager(node_manager: NodeManager):
    """
    Get file path of type(instance) in NodeManager.

    Args:
        node_manager (NodeManager): Node manager whose top manager is inspected.

    Returns:
        The result of inspect.getfile on the class of the origin network when the top manager
        is a SymbolTree, otherwise on the object returned by the top manager's get_instance().
    """
    top_manager = node_manager.get_top_manager()
    if isinstance(top_manager, SymbolTree):
        located_obj = type(top_manager.get_origin_network())
    else:
        located_obj = top_manager.get_instance()
    return inspect.getfile(located_obj)
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def _get_module_of_node_manager(node_manager: NodeManager):
    """
    Get module where the node manager is located.

    Args:
        node_manager (NodeManager): Node manager whose defining source file is looked up.

    Returns:
        A tuple of (module, func_path). The module is the first entry of sys.modules whose
        `__file__` equals the normalized source path of the node manager, or None when no
        loaded module matches. func_path is that normalized source path.
    """
    # get module where function object is used
    func_path = AssignParser._get_path_of_node_manager(node_manager)
    func_path = os.path.normcase(os.path.normpath(func_path))
    # Iterate over a copy of the module list, as the original code does.
    for module in list(sys.modules.values()):
        if not hasattr(module, "__file__") or module.__file__ is None:
            continue
        # Normalize the module path exactly like func_path; comparing a normpath'ed path with a
        # raw `__file__` would miss equivalent spellings such as 'pkg/./mod.py'.
        if func_path == os.path.normcase(os.path.normpath(module.__file__)):
            return module, func_path
    return None, func_path
|
|
178
|
+
|
|
179
|
+
@staticmethod
|
|
180
|
+
def _get_object_from_module(func_full_name: str, module: types.ModuleType):
|
|
181
|
+
"""Get object from module according to full name of function"""
|
|
182
|
+
names = func_full_name.split('.')
|
|
183
|
+
obj = module
|
|
184
|
+
for attr in names:
|
|
185
|
+
if not hasattr(obj, attr):
|
|
186
|
+
logger.info(f"For '{func_full_name}', failed to get attr '{attr}' from '{obj}'")
|
|
187
|
+
return None
|
|
188
|
+
obj = getattr(obj, attr)
|
|
189
|
+
return obj
|
|
190
|
+
|
|
191
|
+
@staticmethod
|
|
192
|
+
def _get_local_var_provider(node_manager: NodeManager, var: str) -> Node:
    """
    Get the node providing specific variable.

    Args:
        node_manager (NodeManager): Node manager whose nodes are searched first.
        var (str): Name of the variable, compared with the string form of each node target.

    Returns:
        The closest node to the tail having var among its targets, or None when no node is found.
    """
    manager = node_manager
    while True:
        # Walk from the tail backwards so that the latest provider is found first.
        current = manager.get_tail()
        while current is not None:
            if any(str(target) == var for target in current.get_targets()):
                return current
            current = current.get_prev()
        # When node_manager is control flow, nodes in upper node_manager need to be traversed.
        if not isinstance(manager, ControlFlow):
            return None
        manager = manager.get_node_manager()
|
|
203
|
+
|
|
204
|
+
def target(self):
    """Return the ast node type handled by this parser, which is ast.Assign."""
    handled_type = ast.Assign
    return handled_type
|
|
207
|
+
|
|
208
|
+
def store_env(self):
    """Push the current parsing context onto the environment stack, then unset every context field."""
    snapshot = [self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs]
    self._variables_cache.append(snapshot)
    self.stree = self.ast_assign = self.node_manager = None
    self.targets = self.args = self.kwargs = None
|
|
218
|
+
|
|
219
|
+
def restore_env(self):
    """Pop the most recently stored parsing context and make it the current one."""
    saved = self._variables_cache.pop()
    (self.stree, self.ast_assign, self.node_manager,
     self.targets, self.args, self.kwargs) = saved
|
|
223
|
+
|
|
224
|
+
def _get_cell_instance(self, func_scope, func_name):
|
|
225
|
+
"""
|
|
226
|
+
Get object instance from ast.Call with type of Cell.
|
|
227
|
+
|
|
228
|
+
Args:
|
|
229
|
+
func_scope (str): Func scope.
|
|
230
|
+
func_name (str): Func name.
|
|
231
|
+
|
|
232
|
+
Returns:
|
|
233
|
+
An instance represents operator instance.
|
|
234
|
+
"""
|
|
235
|
+
if func_scope != "self":
|
|
236
|
+
return None
|
|
237
|
+
var_dict = self.stree.get_origin_network().__dict__
|
|
238
|
+
# Instance is of type Cell
|
|
239
|
+
for key, value in var_dict["_cells"].items():
|
|
240
|
+
if key == func_name:
|
|
241
|
+
return value
|
|
242
|
+
# Instance is of other type.
|
|
243
|
+
return None
|
|
244
|
+
|
|
245
|
+
def _get_primitive_instance(self, func_scope, func_name):
|
|
246
|
+
"""
|
|
247
|
+
Get object instance from ast.Call with type of Primitive.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
func_scope (str): Func scope.
|
|
251
|
+
func_name (str): Func name.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
An instance represents operator instance.
|
|
255
|
+
"""
|
|
256
|
+
if func_scope != "self":
|
|
257
|
+
return None
|
|
258
|
+
var_dict = self.stree.get_origin_network().__dict__
|
|
259
|
+
# Instance is of type Primitive
|
|
260
|
+
for key, value in var_dict["_primitives"].items():
|
|
261
|
+
if key == func_name:
|
|
262
|
+
return value
|
|
263
|
+
# Instance is of other type.
|
|
264
|
+
return None
|
|
265
|
+
|
|
266
|
+
def _get_method_object(self, func_scope, func_name):
|
|
267
|
+
"""Get method object from network instance."""
|
|
268
|
+
stree = self.stree
|
|
269
|
+
if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name):
|
|
270
|
+
return getattr(stree.get_origin_network(), func_name)
|
|
271
|
+
return None
|
|
272
|
+
|
|
273
|
+
def _get_local_variable(self, func_scope, func_name) -> (bool, object):
    """
    Get local variable.

    The provider of the variable is searched first among the nodes returned by
    `self.stree.local_prim_inits()` (only when the scope is "self"; the last matching node is
    kept), then among the nodes of the current node manager.

    Args:
        func_scope (str): Func scope.
        func_name (str): Func name.

    Returns:
        bool: Indicate whether local variable is found.
        object (Union[LocalPrim, type]): Instance of LocalPrim when calling the class, or class type
            object when initializing the class. None when the local variable is found but is not
            related to a Primitive, or when no local variable is found.
    """
    func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name
    # try to find func_name in class variables initializing the primitive during forward method
    provider_node = None
    if func_scope == "self":
        # no break: when several nodes provide the name, the last one is used
        for node in self.stree.local_prim_inits():
            if func_full_name in [str(target) for target in node.get_targets()]:
                provider_node = node
    # try to find func_name in local variables
    if provider_node is None:
        provider_node = AssignParser._get_local_var_provider(self.node_manager, func_full_name)
    if provider_node:
        # when the node providing the local variable initialized a primitive during forward method,
        # we use LocalPrim to indicate the instance of this primitive. e.g. :
        # abs_inst = P.Abs()  -> 'abs_inst' is an instance of primitive initialized locally
        # y = abs_inst(x)     -> here we are parsing now
        cls_init = provider_node.get_init_cls()
        if cls_init and inspect.isclass(cls_init) and issubclass(cls_init, Primitive):
            return True, LocalPrim(cls_init)
        # when the node providing the local variable represents a primitive type object, we return
        # type-object to indicate that we are initializing this primitive. e.g. :
        # abs_ops = _get_cache_prim(P.Abs)  -> 'abs_ops' is a primitive type object
        # y = abs_ops(x)                    -> here we are parsing now
        cls_type = provider_node.get_type_cls()
        if cls_type and inspect.isclass(cls_type) and issubclass(cls_type, Primitive):
            return True, cls_type
        # local variable whose type is not primitive instance
        logger.info(f"Ignore local variable: {func_full_name}")
        return True, None
    # other local variable: only the first component of the dotted name is provided locally
    if AssignParser._get_local_var_provider(self.node_manager, func_full_name.split('.')[0]):
        logger.info(f"Ignore local variable: {func_full_name}")
        return True, None
    return False, None
|
|
319
|
+
|
|
320
|
+
def _get_function_object(self, func_scope, func_name, ast_call) -> (object, bool):
    """
    Get function object from module.

    If the code represents a class type object, e.g. abs_ops = _get_cache_prim(P.Abs),
    return primitive type object with class type flag True.

    If the code represents an initialization of a class, e.g. abs_inst = P.Abs(),
    return primitive type object with class type flag False.

    If the code represents the call of function or class instance, e.g. y = abs_inst(x)/func(x),
    return primitive instance or function object with class type flag False.

    Args:
        func_scope (str): Func scope.
        func_name (str): Func name.
        ast_call (ast.Call): ast.Call of ast.Assign.

    Returns:
        object: Class type object, class instance or function object. None when the module of
            the current node manager or the object itself cannot be found.
        bool: Flag indicating whether the node represents a class type object.
    """
    func_full_name = func_name if not func_scope else f"{func_scope}.{func_name}"
    # get module where function object is used
    module, func_path = AssignParser._get_module_of_node_manager(self.node_manager)
    if module is None:
        logger.debug(f"When getting object of '{func_full_name}', failed to find module in '{func_path}'")
        return None, False
    # if name of function is _get_cache_prim, the object to find is its first argument
    is_cls_type_obj = func_full_name == '_get_cache_prim'
    if is_cls_type_obj:
        func_full_name = astunparse.unparse(ast_call.args[0]).strip()
    # find object in module
    return AssignParser._get_object_from_module(func_full_name, module), is_cls_type_obj
|
|
356
|
+
|
|
357
|
+
def _update_field_in_init(self, func_name: str, sub_tree: SymbolTree) -> bool:
|
|
358
|
+
"""
|
|
359
|
+
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
|
360
|
+
Add the code like: `self.field = SubNetwork(self.field)`
|
|
361
|
+
|
|
362
|
+
Args:
|
|
363
|
+
func_name (str): A string represents scope and name of function symbol.
|
|
364
|
+
sub_tree (SymbolTree): The SymbolTree corresponding to sub-network.
|
|
365
|
+
"""
|
|
366
|
+
init_func_ast = self.stree.get_init_func_ast()
|
|
367
|
+
sub_net_obj = sub_tree.get_origin_network()
|
|
368
|
+
sub_net_opt_name = sub_tree.get_opt_cls_name()
|
|
369
|
+
# Add .to_float(mindspore.float16) if origin subnet has this attribute
|
|
370
|
+
new_code = f"{func_name} = {sub_net_opt_name}({func_name})"
|
|
371
|
+
if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16:
|
|
372
|
+
new_code = f"{new_code}.to_float(mindspore.float16)"
|
|
373
|
+
elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16:
|
|
374
|
+
new_code = f"{new_code}.to_float(mindspore.bfloat16)"
|
|
375
|
+
new_ast = ast.parse(new_code).body[0]
|
|
376
|
+
init_func_ast.body.append(new_ast)
|
|
377
|
+
|
|
378
|
+
def _update_cell_container_in_init(self, container_name, container_idx, subnet_opt_name):
|
|
379
|
+
"""
|
|
380
|
+
When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object.
|
|
381
|
+
So the assign code will be got from origin code first, and then be modified to new class name.
|
|
382
|
+
|
|
383
|
+
Codes like:
|
|
384
|
+
|
|
385
|
+
`self.container = nn.SequentialCell([ReLU(), MyNet()])`
|
|
386
|
+
|
|
387
|
+
will be updated by add codes:
|
|
388
|
+
|
|
389
|
+
`self.container[1] = MyNetOpt(self.container[1])`
|
|
390
|
+
|
|
391
|
+
"""
|
|
392
|
+
new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])"
|
|
393
|
+
new_ast = ast.parse(new_code).body[0]
|
|
394
|
+
self.stree.get_init_func_ast().body.append(new_ast)
|
|
395
|
+
|
|
396
|
+
def _add_import(self, import_name: str):
    """
    Add import to current node manager.

    Nothing is added when the module of the current node manager cannot be found.

    Args:
        import_name (str): Name to import from the module where the current node manager is located.
    """
    module, _ = AssignParser._get_module_of_node_manager(self.node_manager)
    if module is None:
        logger.info(f"Cannot get module where '{import_name}' is located, ignore import info")
        return
    top_manager = self.node_manager.get_top_manager()
    if isinstance(top_manager, SymbolTree):
        belonging_ast = None
    else:
        belonging_ast = top_manager.get_manager_ast()
    self.stree.add_import(module, import_name, belonging_ast)
|
|
405
|
+
|
|
406
|
+
def cell_container_process(self, func_name: str, node_name: str, container_obj: object):
    """
    Parse cell container object.

    One sub node is created for every cell of the container: nested containers are processed
    recursively, cells accepted by `is_subtree` get their own SymbolTree, and all other cells
    become nodes created by `Node.create_call_buildin_op`.

    Args:
        func_name (str): Source text naming the container; element `i` is addressed as
            `{func_name}[{i}]`.
        node_name (str): Name of the node to create.
        container_obj (object): The container instance; it is iterated to get its cells.

    Returns:
        The CellContainer node holding the sub nodes, or a node created by
        `Node.create_call_buildin_op` when one implementation is shared and this container
        object is already cached.
    """
    # create unparsable node if container is already parsed when sharing one implementation
    if AssignParser._share_one_implementation and id(container_obj) in AssignParser._cached_cell_containers:
        cell_container = Node.create_call_buildin_op(container_obj, self.ast_assign, self.targets,
                                                     func_name, self.args, self.kwargs, node_name)
        return cell_container
    cell_container = CellContainer(self.ast_assign, self.targets, func_name, self.args, self.kwargs,
                                   node_name, self.stree, container_obj)
    for i, cell in enumerate(container_obj):
        cell_name = type(cell).__name__
        # The type of cell is container of cells (e.g. SequentialCell)
        if isinstance(cell, tuple(AssignParser.types_for_cell_container)):
            sub_node = self.cell_container_process(f"{func_name}[{i}]", cell_name, cell)
        elif is_subtree(cell):
            # create unparsable node if tree node is already parsed when sharing one implementation
            if AssignParser._share_one_implementation and id(cell) in AssignParser._cached_trees:
                first_stree = AssignParser._cached_trees.get(id(cell))
                self._update_cell_container_in_init(func_name, i, first_stree.get_opt_cls_name())
                sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
                                                       self.kwargs, cell_name)
            else:
                # imported here rather than at module level; presumably to avoid a circular import
                from ..symbol_tree import SymbolTreeBuilder
                stb = SymbolTreeBuilder(cell)
                new_stree = stb.build()
                sub_node = TreeNode.create_tree_node(new_stree, None, self.targets, cell_name, self.args,
                                                     self.kwargs, cell_name, cell)
                self._update_cell_container_in_init(func_name, i, new_stree.get_opt_cls_name())
                # save symbol tree if it is firstly parsed when sharing one implementation
                if AssignParser._share_one_implementation:
                    AssignParser._cached_trees[id(cell)] = new_stree
        else:
            sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args,
                                                   self.kwargs, cell_name)
        # add sub node to cell_container
        cell_container.append(sub_node, False)
    # save the node if container is firstly parsed when sharing one implementation
    if AssignParser._share_one_implementation:
        AssignParser._cached_cell_containers[id(container_obj)] = cell_container
    return cell_container
|
|
446
|
+
|
|
447
|
+
def process_cell(self, func_scope_name: ScopedValue, node_name: str, cell_inst: Cell):
    """
    Create CallCell node with instance of cell.

    The created node is appended to the symbol tree through `append_origin_field`, using the
    current node manager.

    Args:
        func_scope_name (ScopedValue): Scope and name of the called symbol.
        node_name (str): Name of the node to create.
        cell_inst (Cell): Instance of the cell being called.
    """
    # The type of cell is container of cells (e.g. SequentialCell)
    if isinstance(cell_inst, tuple(AssignParser.types_for_cell_container)):
        # NOTE(review): a ScopedValue is passed where cell_container_process declares `func_name: str`
        node = self.cell_container_process(func_scope_name, node_name, cell_inst)
    # The type of cell is user custom network, then we create sub-symboltree
    elif is_subtree(cell_inst):
        # create unparsable node if tree node is already parsed when sharing one implementation
        if AssignParser._share_one_implementation and id(cell_inst) in AssignParser._cached_trees:
            first_stree = AssignParser._cached_trees.get(id(cell_inst))
            self._update_field_in_init(str(func_scope_name), first_stree)
            node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name,
                                               self.args, self.kwargs, node_name)
        else:
            # imported here rather than at module level; presumably to avoid a circular import
            from ..symbol_tree import SymbolTreeBuilder
            stb = SymbolTreeBuilder(cell_inst)
            new_stree = stb.build()
            self._update_field_in_init(str(func_scope_name), new_stree)
            node = TreeNode.create_tree_node(new_stree, self.ast_assign, self.targets, func_scope_name,
                                             self.args, self.kwargs, node_name, new_stree.get_origin_network())
            # save symbol tree if it is firstly parsed when sharing one implementation
            if AssignParser._share_one_implementation:
                AssignParser._cached_trees[id(cell_inst)] = new_stree
    else:
        # The type of cell is built-in cells
        node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name, self.args,
                                           self.kwargs, node_name)
    self.stree.append_origin_field(node, self.node_manager)
|
|
475
|
+
|
|
476
|
+
def process_primitive(self, func_scope_name: ScopedValue, node_name: str, primitive_inst: Primitive):
    """
    Create a built-in op call node for a Primitive instance and append it to the symbol tree.

    Args:
        func_scope_name (ScopedValue): Scoped name of the called primitive.
        node_name (str): Name given to the created node.
        primitive_inst (Primitive): The primitive instance being called.
    """
    primitive_node = Node.create_call_buildin_op(primitive_inst, self.ast_assign, self.targets,
                                                 func_scope_name, self.args, self.kwargs, node_name)
    self.stree.append_origin_field(primitive_node, self.node_manager)
|
|
481
|
+
|
|
482
|
+
def process_class_method(self, func_scope_name: ScopedValue, node_name: str, method_object: object):
    """
    Create a CallFunction node for a method of the network class.

    The method's ``ast.FunctionDef`` is looked up by name in the body of the class ast held by the
    symbol tree. When several definitions share the name, the last one in the body is used.

    Args:
        func_scope_name (ScopedValue): Scoped name of the called method; its ``value`` is the method name.
        node_name (str): Name given to the created node.
        method_object (object): The bound method object.
    """
    method_name = func_scope_name.value
    class_body = self.stree.get_class_ast().body
    # Searching backwards and taking the first hit yields the last matching definition.
    method_ast = next((stmt for stmt in reversed(class_body)
                       if isinstance(stmt, ast.FunctionDef) and method_name == stmt.name), None)
    if method_ast is not None:
        # create CallFunction node
        self.insert_callfunction_node(func_scope_name, node_name, method_ast, method_object, True)
        return
    # method of child class may be called and will be ignored now.
    logger.info(error_str(f"Find ast of function '{method_name}' in network '{self.stree.get_ori_cls_name()}' "
                          f"failed", child_node=self.ast_assign))
    self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
|
|
498
|
+
|
|
499
|
+
def process_function(self, func_scope_name: ScopedValue, node_name: str, function_object: object,
                     is_cls_type_obj: bool):
    """
    Create the node for a call whose callee is a function, a primitive or a class object.

    The callee is only parsed into a CallFunction sub-tree when its source can be retrieved and it is
    a plain function without nested function definitions. In every other case an unparsed call node
    is inserted instead, so that rewriting can continue.

    Args:
        func_scope_name (ScopedValue): Scoped name of the callee as written in the assign statement.
        node_name (str): Name given to the created node.
        function_object (object): The resolved callee object.
        is_cls_type_obj (bool): For a class callee, True when the statement yields the class type
            itself (e.g. ``abs_ops = _get_cache_prim(P.Abs)``) rather than an instance
            (e.g. ``abs_inst = P.Abs()``).
    """
    # Ignore functions in _function_parse_black_list
    if function_object in AssignParser._function_parse_black_list:
        logger.debug(f"'{func_scope_name}' is in the _function_parse_black_list and will not be parsed")
        if not func_scope_name.scope:
            self._add_import(func_scope_name.value)
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # break loop function: walk up the enclosing node managers and stop parsing when the same
    # function object is already being parsed further up.
    node_manager = self.node_manager
    while node_manager and isinstance(node_manager, Node):
        if isinstance(node_manager, CallFunction) and node_manager.get_instance() == function_object:
            logger.info(f"loop function detected in '{func_scope_name}', stop parsing function.")
            self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
            return
        node_manager = node_manager.get_node_manager()
    # process primitive instances:
    # (global/local)  _ops_func = P.FUNC()
    # (here)          y = _ops_func(x) <- (process: _ops_func)
    if isinstance(function_object, Primitive):
        # when primitive instance is not a local variable, it will be a global object which need to be imported
        if not isinstance(function_object, LocalPrim):
            import_name = str(func_scope_name).split('.')[0]
            self._add_import(import_name)
        # create CallPrimitive node
        self.process_primitive(func_scope_name, func_scope_name.value, function_object)
        return
    # process primitive object:
    # (here)   _ops_func = P.FUNC() <- (process: P.FUNC)
    # (later)  y = _ops_func(x)
    if inspect.isclass(function_object):
        node = self.insert_callfunction_node(func_scope_name, node_name, None, None, False)
        if is_cls_type_obj:
            # represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs)
            node.set_type_cls(function_object)
            # add import
            if str(func_scope_name) == '_get_cache_prim':
                import_name = astunparse.unparse(self.ast_assign.value.args[0]).strip()
                if '.' not in import_name:
                    self._add_import(import_name)
        else:
            # represent the initialize of a class type, e.g. abs_inst = P.Abs()
            node.set_init_cls(function_object)
            # record local primitive objects
            if func_scope_name.scope == 'self' and issubclass(function_object, Primitive):
                self.stree.local_prim_inits.append(node)
        return
    # process third party functions
    is_ms_func = is_ms_function(function_object)
    if not is_ms_func and is_third_party(function_object):
        logger.info(f"Ignore third party function '{func_scope_name}'.")
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # process mindspore functions
    if is_ms_func and AssignParser._skip_ms_function:
        logger.info(f"Ignore mindspore function '{func_scope_name}'.")
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # get ast.FunctionDef
    # inspect.getsource raises OSError/TypeError when no source is available, and the retrieved
    # snippet may not parse on its own. Fall back to an unparsed node like the other branches do
    # instead of aborting the whole parse.
    try:
        source_code = inspect.getsource(function_object)
        ast_functiondef = ast.parse(dedent(source_code)).body[0]
    except (OSError, TypeError, SyntaxError, IndexError) as err:
        logger.info(error_str(f"Get source code of function {str(func_scope_name)} failed: {err}",
                              child_node=self.ast_assign))
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    if not isinstance(ast_functiondef, ast.FunctionDef):
        logger.info(error_str(f"Get ast.FunctionDef of function {str(func_scope_name)} failed, the type of "
                              f"ast node is {type(ast_functiondef)}", child_node=self.ast_assign))
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    if any(isinstance(stmt, ast.FunctionDef) for stmt in ast_functiondef.body):
        logger.info(error_str(f"closure syntax is not supported now, {str(func_scope_name)} will not be parsed.",
                              child_node=ast_functiondef))
        if not func_scope_name.scope:
            self._add_import(func_scope_name.value)
        self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False)
        return
    # update func_name, and remove scope
    new_name = ast_functiondef.name
    # when func_scope_name(e.g. 'C.uniform') is not the name in ast.FunctionDef(e.g. 'uniform'), this name may be
    # already used as variable(e.g. uniform = C.uniform(x)).
    # To avoid new function's name being duplicated with existed variable, a suffix '_opt' will be added.
    if new_name != str(func_scope_name):
        new_name = f"{new_name}_opt"
    new_name = FunctionNamer().instance().get_name(new_name)
    ast_call: ast.Call = self.ast_assign.value
    function_key = id(function_object)
    # create unparsable node if function is already parsed when sharing one implementation
    if AssignParser._share_one_implementation and function_key in AssignParser._cached_functions:
        first_node = AssignParser._cached_functions.get(function_key)
        # Redirect the call to the function that was generated the first time.
        ast_call.func = ast.Name(id=str(first_node.get_func_name()), ctx=ast.Load())
        self.insert_callfunction_node(func_scope_name, new_name, None, function_object, False)
        return
    ast_functiondef.name = new_name
    ast_call.func = ast.Name(id=new_name, ctx=ast.Load())
    # save ast.FunctionDef into stree._external_ast
    self.stree.get_external_ast()[ast_functiondef] = []
    # import module which function defined in
    func_file_path = inspect.getabsfile(function_object)
    self.stree.save_imports_from_file(func_file_path, ast_functiondef)
    # create CallFunction node
    func_scope_name = ScopedValue.create_naming_value(new_name, "")
    node = self.insert_callfunction_node(func_scope_name, new_name, ast_functiondef, function_object, False)
    # save function node if it is firstly parsed when sharing one implementation
    if AssignParser._share_one_implementation:
        AssignParser._cached_functions[function_key] = node
|
|
603
|
+
|
|
604
|
+
def insert_callfunction_node(self, func_name: ScopedValue, node_name: str, ast_functiondef: ast.FunctionDef,
                             func_obj: object, is_method: bool) -> Node:
    """
    Create a call node for a function, append it to the symbol tree and return it.

    Args:
        func_name (ScopedValue): Scoped name of the called function.
        node_name (str): Name given to the created node.
        ast_functiondef (ast.FunctionDef): Definition of the function. When it is None, the call is
            kept as a plain call node and the function body is not parsed.
        func_obj (object): The function object, may be None.
        is_method (bool): Forwarded to CallFunction when the body is parsed.

    Returns:
        Node, the node appended to the symbol tree.
    """
    if ast_functiondef is not None:
        # create CallFunction node
        func_node = CallFunction(self.targets, func_name, self.args, self.kwargs, node_name, self.ast_assign,
                                 ast_functiondef, self.stree, func_obj, is_method)
        self.stree.append_origin_field(func_node, self.node_manager)
        # expand ast codes
        flattened_def = AstFlattener().transform(ast_functiondef, [func_name.value], self.stree)
        # parse ast codes into CallFunction Node
        def_parser = ParserRegister.instance().get_parser(ast.FunctionDef)
        def_parser.process(self.stree, flattened_def, node_manager=func_node)
        return func_node
    # No definition available: keep the call as a single node.
    plain_node = Node.inner_create_call_function(node_name, self.ast_assign, func_name, func_obj,
                                                 self.targets, self.args, self.kwargs)
    self.stree.append_origin_field(plain_node, self.node_manager)
    return plain_node
|
|
622
|
+
|
|
623
|
+
def process_ast_call(self, ast_call: ast.Call):
    """
    Convert ast.Call to a symbol tree node.

    The callee is resolved in this order: cell instance, primitive instance, method object of the
    network, local variable, function object. Callees that cannot be resolved, or that must not be
    parsed, are kept as an unparsed call node.

    Args:
        ast_call (ast.Call): An ast.Call of assign node in construct.
    """
    self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
    self.args = [AstConverter.create_scopedvalue(arg) for arg in ast_call.args]
    self.kwargs = AssignParser._create_kwargs(ast_call.keywords)
    callee = AssignParser._get_func_name(ast_call)
    scope = AssignParser._get_func_scope(ast_call)
    scoped_callee = ScopedValue.create_naming_value(callee, scope)
    full_name = str(scoped_callee)

    def _insert_unparsed():
        # Shared fallback: a call node without function definition and without function object.
        self.insert_callfunction_node(scoped_callee, callee, None, None, False)

    # y = func(xxx)(xxx) / y = func1(xxx).func2(xxx) is not supported, and should be flattened before parsing.
    if AstFinder(ast_call.func).find_all(ast.Call):
        logger.info(error_str("ast.Call in func name of ast.Call is not supported.", ast_call, self.ast_assign))
        _insert_unparsed()
        return
    # Ignore built-in functions
    if full_name in dir(builtins):
        logger.info(f"Ignore built-in function: {scoped_callee}")
        _insert_unparsed()
        return
    # Ignore function name is target of for loop
    if isinstance(self.node_manager, ControlFlow) and full_name in self.node_manager.loop_vars:
        logger.info(f"Ignore function of loop variable: {scoped_callee}")
        _insert_unparsed()
        return
    # Instance with type of Cell
    cell = self._get_cell_instance(scope, callee)
    if cell is not None:
        self.process_cell(scoped_callee, callee, cell)
        return
    # Instance with type of Primitive
    primitive = self._get_primitive_instance(scope, callee)
    if primitive is not None:
        self.process_primitive(scoped_callee, callee, primitive)
        return
    # Class method object
    method = self._get_method_object(scope, callee)
    if method is not None:
        if inspect.ismethod(method):
            self.process_class_method(scoped_callee, callee, method)
        elif isinstance(inspect.getattr_static(self.stree.get_origin_network(), callee), staticmethod):
            _insert_unparsed()
        else:
            self.process_function(scoped_callee, callee, method, False)
        return
    # Local variable
    is_local_var, local_primitive = self._get_local_variable(scope, callee)
    if local_primitive is not None:
        self.process_function(scoped_callee, callee, local_primitive, False)
        return
    if is_local_var:
        # for a variable whose type is not primitive instance, create normal node for it
        _insert_unparsed()
        return
    # Function object
    function, is_cls_type_obj = self._get_function_object(scope, callee, ast_call)
    if function is not None:
        self.process_function(scoped_callee, callee, function, is_cls_type_obj)
        return
    logger.info(error_str("Failed to get instance or object of ast.Call.", ast_call, self.ast_assign))
    _insert_unparsed()
|
|
688
|
+
|
|
689
|
+
def process_ast_mathops(self, ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]):
    """
    Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
    a symbol tree node.

    The node name is the lower-cased ast type name followed by the operator type name(s),
    e.g. ``binop_add`` for ``a + b`` or ``compare_lt_lt`` for ``a < b < c``. The operands become
    the arguments of the node in source order.

    Args:
        ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematical
            operation in construct function.

    Raises:
        TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
    """
    if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
        # Format the offending type into the message; passing it as a second positional argument
        # would make str(error) print a tuple.
        raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, "
                        f"ast.BoolOp, ast.Compare), but got {type(ast_op)}")

    targets = AssignParser._create_targets(self.ast_assign.targets[0])
    args = []
    op_type_str = type(ast_op).__name__
    op_type = ScopedValue.create_naming_value(op_type_str)
    name = op_type_str
    if isinstance(ast_op, ast.BinOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        args.append(AstConverter.create_scopedvalue(ast_op.left))
        args.append(AstConverter.create_scopedvalue(ast_op.right))
    elif isinstance(ast_op, ast.UnaryOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        args.append(AstConverter.create_scopedvalue(ast_op.operand))
    elif isinstance(ast_op, ast.BoolOp):
        name = f'{name}_{type(ast_op.op).__name__}'
        args.extend(AstConverter.create_scopedvalue(value) for value in ast_op.values)
    else:
        # ast.Compare: one operator per comparator, chained left to right.
        args.append(AstConverter.create_scopedvalue(ast_op.left))
        for ast_cmp_op, comparator in zip(ast_op.ops, ast_op.comparators):
            name = f'{name}_{type(ast_cmp_op).__name__}'
            args.append(AstConverter.create_scopedvalue(comparator))
    name = name.lower()
    node = Node.create_mathops_node(self.ast_assign, targets, op_type, args, name)
    self.stree.append_origin_field(node, self.node_manager)
|
|
734
|
+
|
|
735
|
+
def process_ast_constant(self, ast_constant: Union[ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str]):
    """
    Convert an assignment of a constant (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str)
    to a symbol tree node.

    The constant is wrapped as the single argument of a "pass_through" call-method node whose name
    is the lower-cased ast type name followed by "_assign".

    Args:
        ast_constant (Union[ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str]): Value of
            the assign statement.
    """
    ast_type_name = type(ast_constant).__name__.lower()
    assign_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    constant_args = [AstConverter.create_scopedvalue(ast_constant)]
    constant_node = Node.create_call_method(self.ast_assign, assign_targets, "pass_through", constant_args, {},
                                            f"{ast_type_name}_assign")
    self.stree.append_origin_field(constant_node, self.node_manager)
|
|
745
|
+
|
|
746
|
+
def process_ast_name(self, ast_node: Union[ast.Name, ast.Attribute]):
    """
    Convert an assignment whose value is an ast.Name or ast.Attribute to a symbol tree node.

    When the name resolves to a CellList, or to a list accepted by ``AssignParser._list_of_cells``,
    the instance is handed to ``cell_container_process``. Otherwise the value is wrapped as the single
    argument of a "pass_through" call-method node.

    Args:
        ast_node (Union[ast.Name, ast.Attribute]): Value of the assign statement.
    """
    self.targets = AssignParser._create_targets(self.ast_assign.targets[0])
    inst, scope_name = AssignParser._get_inst_and_name(ast_node, self.stree)
    is_cell_container = inst is not None and (
        isinstance(inst, CellList) or (isinstance(inst, list) and AssignParser._list_of_cells(inst)))
    if not is_cell_container:
        pass_through_args = [AstConverter.create_scopedvalue(ast_node)]
        node = Node.create_call_method(self.ast_assign, self.targets, "pass_through", pass_through_args, {},
                                       f"{type(ast_node).__name__.lower()}_assign")
    else:
        node = self.cell_container_process(scope_name, scope_name, inst)
    self.stree.append_origin_field(node, self.node_manager)
|
|
760
|
+
|
|
761
|
+
def process_ast_tuple(self, ast_node: Union[ast.Tuple, ast.List]):
    """
    Convert an assignment whose value is an ast.Tuple or ast.List to a symbol tree node.

    A "tuple" or "list" call-method node is created when every element can be expressed as a scoped
    value. Otherwise the whole assign statement is appended as a python node.

    Args:
        ast_node (Union[ast.Tuple, ast.List]): Value of the assign statement.
    """
    # Guard: fall back to a python node when some element is not supported by scoped value.
    if not AstConverter.ast_tuple_elts_support_scopledvalue(ast_node):
        logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
                    "in rewrite, fallback to python")
        self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
        return
    assign_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    element_args = [AstConverter.create_scopedvalue(element) for element in ast_node.elts]
    if isinstance(ast_node, ast.Tuple):
        container_name = "tuple"
    else:
        container_name = "list"
    container_node = Node.create_call_method(self.ast_assign, assign_targets, container_name, element_args, {},
                                             container_name)
    self.stree.append_origin_field(container_node, self.node_manager)
|
|
778
|
+
|
|
779
|
+
def process_ast_dict(self, ast_dict: ast.Dict):
    """
    Convert an assignment whose value is an ast.Dict to a symbol tree node.

    A "dict" call-method node is created, with one keyword argument per dict entry, when the dict is
    accepted by ``AstConverter.ast_dict_support_scopledvalue``. Otherwise the whole assign statement
    is appended as a python node.

    Args:
        ast_dict (ast.Dict): Value of the assign statement.
    """
    # Guard: fall back to a python node when some element is not supported by scoped value.
    if not AstConverter.ast_dict_support_scopledvalue(ast_dict):
        logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported "
                    "in rewrite, fallback to python")
        self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager)
        return
    assign_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    # ast.Dict keeps keys and values as parallel lists.
    entry_kwargs = {key.value: AstConverter.create_scopedvalue(value)
                    for key, value in zip(ast_dict.keys, ast_dict.values)}
    dict_func = ScopedValue.create_naming_value("dict")
    dict_node = Node.create_call_method(self.ast_assign, assign_targets, dict_func, [], entry_kwargs, "dict")
    self.stree.append_origin_field(dict_node, self.node_manager)
|
|
796
|
+
|
|
797
|
+
def process_ast_subscript(self, ast_subscript: ast.Subscript):
    """
    Convert an assignment whose value is an ast.Subscript to a symbol tree node.

    The subscript expression is wrapped as the single argument of a "pass_through" call-method node
    named "subscript_var".

    Args:
        ast_subscript (ast.Subscript): Value of the assign statement.
    """
    assign_targets = AssignParser._create_targets(self.ast_assign.targets[0])
    subscript_args = [AstConverter.create_scopedvalue(ast_subscript)]
    subscript_node = Node.create_call_method(self.ast_assign, assign_targets, "pass_through", subscript_args, {},
                                             "subscript_var")
    self.stree.append_origin_field(subscript_node, self.node_manager)
|
|
805
|
+
|
|
806
|
+
def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager):
    """
    Parse ast.Assign and create a node in symbol tree.

    The value of the assign statement selects the handler:

    - ast.Call -> process_ast_call
    - ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare -> process_ast_mathops
    - ast.Subscript -> process_ast_subscript
    - ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str -> process_ast_constant
    - ast.Name, ast.Attribute -> process_ast_name
    - ast.Tuple, ast.List -> process_ast_tuple
    - ast.Dict -> process_ast_dict
    - Other value types are appended as python node.

    Statements with more than one target (e.g. 'a = b = 1') are appended as python node as well.

    Args:
        stree ([SymbolTree]): Symbol Tree under parsing.
        node ([ast.Assign]): An ast.Assign node.
        node_manager (NodeManager): NodeManager those asts belong to.
    """
    if len(node.targets) != 1:
        logger.info(error_str("Continuous assignment statement(e.g. 'a = b = 1') should be flatten before.",
                              child_node=node))
        stree.try_append_python_node(node, node, node_manager)
        return

    # store_env/restore_env presumably save and restore the state of this parser instance around
    # nested parsing - confirm against their definitions. The finally clause keeps the pair balanced
    # when a handler raises.
    self.store_env()
    try:
        self.stree = stree
        self.ast_assign = node
        self.node_manager = node_manager
        value = node.value
        if isinstance(value, ast.Call):
            self.process_ast_call(value)
        elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
            self.process_ast_mathops(value)
        elif isinstance(value, ast.Subscript):
            self.process_ast_subscript(value)
        elif isinstance(value, (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str)):
            self.process_ast_constant(value)
        elif isinstance(value, (ast.Name, ast.Attribute)):
            self.process_ast_name(value)
        elif isinstance(value, (ast.Tuple, ast.List)):
            self.process_ast_tuple(value)
        elif isinstance(value, ast.Dict):
            self.process_ast_dict(value)
        else:
            logger.info(f"ops-call({astunparse.unparse(node).strip()}) in assign will be supported in near feature, "
                        f"ignored as a python node now")
            stree.try_append_python_node(node, node, node_manager)
    finally:
        self.restore_env()
|
|
850
|
+
|
|
851
|
+
|
|
852
|
+
# Module-level AssignParser instance, created when the module is imported and passed through
# reg_parser (presumably registering it as the parser for ast.Assign - confirm against reg_parser).
g_assign_parser = reg_parser(AssignParser())
|