mindspore 2.4.0__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
- mindspore/_c_expression.cpython-310-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,827 @@
|
|
|
1
|
+
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""basic"""
|
|
16
|
+
from mindspore import context
|
|
17
|
+
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
19
|
+
from mindspore.nn.cell import Cell
|
|
20
|
+
from mindspore.ops.primitive import constexpr
|
|
21
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
22
|
+
from mindspore import _checkparam as validator
|
|
23
|
+
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
|
|
24
|
+
raise_not_implemented_util
|
|
25
|
+
from ._utils.utils import CheckTuple, CheckTensor
|
|
26
|
+
from ._utils.custom_ops import broadcast_to, exp_generic, log_generic
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Distribution(Cell):
|
|
30
|
+
"""
|
|
31
|
+
Base class for all mathematical distributions.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
seed (int): The seed is used in sampling. 0 is used if it is None.
|
|
35
|
+
dtype (mindspore.dtype): The type of the event samples.
|
|
36
|
+
name (str): The name of the distribution.
|
|
37
|
+
param (dict): The parameters used to initialize the distribution.
|
|
38
|
+
|
|
39
|
+
Note:
|
|
40
|
+
Derived class must override operations such as `_mean`, `_prob`,
|
|
41
|
+
and `_log_prob`. Required arguments, such as `value` for `_prob`,
|
|
42
|
+
must be passed in through `args` or `kwargs`. `dist_spec_args` which specifies
|
|
43
|
+
a new distribution are optional.
|
|
44
|
+
|
|
45
|
+
`dist_spec_args` is unique for each type of distribution. For example, `mean` and `sd`
|
|
46
|
+
are the `dist_spec_args` for a Normal distribution, while `rate` is the `dist_spec_args`
|
|
47
|
+
for an Exponential distribution.
|
|
48
|
+
|
|
49
|
+
For all functions, passing in `dist_spec_args`, is optional.
|
|
50
|
+
Function calls with the additional `dist_spec_args` passed in will evaluate the result with
|
|
51
|
+
a new distribution specified by the `dist_spec_args`. However, it will not change the original distribution.
|
|
52
|
+
|
|
53
|
+
Supported Platforms:
|
|
54
|
+
``Ascend`` ``GPU``
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
def __init__(self,
             seed,
             dtype,
             name,
             param):
    """
    Set up the state shared by every distribution.

    Args:
        seed (int): Seed used in sampling; None is replaced by 0.
        dtype (mindspore.dtype): Type of the event samples.
        name (str): Name of the distribution.
        param (dict): Arguments of the concrete distribution's constructor.
            Entries keyed 'self' or starting with '_' are dropped. A
            'param_dict' entry (if present) drives the parameter dtype and
            batch shape; a 'distribution' entry marks a transformed
            distribution, for which those attributes are not set here.
    """
    super(Distribution, self).__init__()
    seed = 0 if seed is None else seed
    validator.check_value_type('name', name, [str], type(self).__name__)
    validator.check_non_negative_int(seed, 'seed', name)

    self._name = name
    self._seed = seed
    self._dtype = cast_type_for_device(dtype)
    self._parameters = {}
    self.default_parameters = []
    self.parameter_names = []

    # Keep only the public constructor arguments.
    for key, val in param.items():
        if key != 'self' and not key.startswith('_'):
            self._parameters[key] = val

    # Plain (non-transformed) distributions derive their parameter dtype and
    # batch shape from their own `param_dict`.
    if 'distribution' not in self.parameters:
        self.parameter_type = set_param_type(
            self.parameters.get('param_dict', {}), dtype)
        self._batch_shape = self._calc_batch_shape()
        self._is_scalar_batch = self._check_is_scalar_batch()
        self._broadcast_shape = self._batch_shape

    # Bind the _call_* dispatchers. Order matters: _set_survival and
    # _set_log_cdf may fall back on _call_cdf, and _set_log_survival on
    # _call_survival, so those must already be bound when they run.
    for bind in (self._set_prob, self._set_log_prob, self._set_sd,
                 self._set_var, self._set_cdf, self._set_survival,
                 self._set_log_cdf, self._set_log_survival,
                 self._set_cross_entropy):
        bind()

    self.context_mode = context.get_context('mode')
    self.device_target = context.get_context('device_target')
    self.checktuple = CheckTuple()

    @constexpr(check=False)
    def _check_tensor(x, name):
        CheckTensor()(x, name)
        return x
    # Graph mode calls CheckTensor directly; outside graph mode the constexpr
    # wrapper forces the check to run only once.
    if self.context_mode == context.GRAPH_MODE:
        self.checktensor = CheckTensor()
    else:
        self.checktensor = _check_tensor
    self.broadcast = broadcast_to

    # Primitives used by the base-class helpers.
    self.cast_base = P.Cast()
    self.dtype_base = P.DType()
    self.sametypeshape_base = inner.SameTypeShape()
    self.sq_base = P.Square()
    self.sqrt_base = P.Sqrt()
    self.shape_base = P.Shape()
    if self.device_target == "Ascend":
        # Ascend uses the generic exp/log helpers from _utils.custom_ops.
        self.exp_base = exp_generic
        self.log_base = log_generic
    else:
        self.log_base = P.Log()
        self.exp_base = P.Exp()
|
|
127
|
+
|
|
128
|
+
@property
|
|
129
|
+
def name(self):
|
|
130
|
+
return self._name
|
|
131
|
+
|
|
132
|
+
@property
|
|
133
|
+
def dtype(self):
|
|
134
|
+
return self._dtype
|
|
135
|
+
|
|
136
|
+
@property
|
|
137
|
+
def seed(self):
|
|
138
|
+
return self._seed
|
|
139
|
+
|
|
140
|
+
@property
|
|
141
|
+
def parameters(self):
|
|
142
|
+
return self._parameters
|
|
143
|
+
|
|
144
|
+
@property
|
|
145
|
+
def is_scalar_batch(self):
|
|
146
|
+
return self._is_scalar_batch
|
|
147
|
+
|
|
148
|
+
@property
|
|
149
|
+
def batch_shape(self):
|
|
150
|
+
return self._batch_shape
|
|
151
|
+
|
|
152
|
+
@property
|
|
153
|
+
def broadcast_shape(self):
|
|
154
|
+
return self._broadcast_shape
|
|
155
|
+
|
|
156
|
+
def _reset_parameters(self):
|
|
157
|
+
self.default_parameters = []
|
|
158
|
+
self.parameter_names = []
|
|
159
|
+
|
|
160
|
+
def _add_parameter(self, value, name):
|
|
161
|
+
"""
|
|
162
|
+
Cast `value` to a tensor and add it to `self.default_parameters`.
|
|
163
|
+
Add `name` into and `self.parameter_names`.
|
|
164
|
+
"""
|
|
165
|
+
# initialize the attributes if they do not exist yet
|
|
166
|
+
if not hasattr(self, 'default_parameters'):
|
|
167
|
+
self.default_parameters = []
|
|
168
|
+
self.parameter_names = []
|
|
169
|
+
# cast value to a tensor if it is not None
|
|
170
|
+
value_t = None if value is None else cast_to_tensor(value, self.parameter_type)
|
|
171
|
+
self.default_parameters.append(value_t)
|
|
172
|
+
self.parameter_names.append(name)
|
|
173
|
+
return value_t
|
|
174
|
+
|
|
175
|
+
def _check_param_type(self, *args):
    """
    Check the availability and validity of default parameters and `dist_spec_args`.
    `dist_spec_args` passed in must be tensors. If default parameters of the distribution
    are None, the parameters must be passed in through `args`.

    Args:
        *args: Candidate parameter values, matched positionally with
            `self.parameter_names` / `self.default_parameters`. A None entry
            falls back to the corresponding default. Because of `zip`, only
            as many parameters are resolved as the shortest of the three.

    Return:
        A single tensor when exactly one parameter is resolved; otherwise a
        tuple of tensors (empty if none were resolved), each cast to
        `self.parameter_type` and broadcast to the common shape.
    """
    broadcast_shape = None
    broadcast_shape_tensor = None
    common_dtype = None
    out = []

    for arg, name, default in zip(args, self.parameter_names, self.default_parameters):
        # check if the argument is a Tensor
        if arg is not None:
            self.checktensor(arg, name)
        else:
            # presumably raise_none_error raises when there is no default either
            arg = default if default is not None else raise_none_error(name)

        # broadcast if the number of args > 1
        if broadcast_shape is None:
            # first parameter: its shape and dtype seed the running all-ones tensor
            broadcast_shape = self.shape_base(arg)
            common_dtype = self.dtype_base(arg)
            broadcast_shape_tensor = F.fill(
                common_dtype, broadcast_shape, 1.0)
        else:
            # grow the running shape by adding against the all-ones tensor
            broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
            broadcast_shape_tensor = F.fill(
                common_dtype, broadcast_shape, 1.0)
            arg = self.broadcast(arg, broadcast_shape_tensor)
            # check if the arguments have the same dtype; the all-ones tensor
            # carries the first parameter's dtype
            self.sametypeshape_base(arg, broadcast_shape_tensor)

        arg = self.cast_base(arg, self.parameter_type)
        out.append(arg)

    if len(out) == 1:
        return out[0]

    # broadcast all args to broadcast_shape; earlier args may have been
    # broadcast to a smaller intermediate shape
    result = ()
    for arg in out:
        arg = self.broadcast(arg, broadcast_shape_tensor)
        result = result + (arg,)
    return result
|
|
219
|
+
|
|
220
|
+
def _check_value(self, value, name):
|
|
221
|
+
"""
|
|
222
|
+
Check availability of `value` as a Tensor.
|
|
223
|
+
"""
|
|
224
|
+
self.checktensor(value, name)
|
|
225
|
+
return value
|
|
226
|
+
|
|
227
|
+
def _check_is_scalar_batch(self):
|
|
228
|
+
"""
|
|
229
|
+
Check if the parameters used during initialization are scalars.
|
|
230
|
+
"""
|
|
231
|
+
param_dict = self.parameters.get('param_dict')
|
|
232
|
+
for value in param_dict.values():
|
|
233
|
+
if value is None:
|
|
234
|
+
continue
|
|
235
|
+
if not isinstance(value, (int, float)):
|
|
236
|
+
return False
|
|
237
|
+
return True
|
|
238
|
+
|
|
239
|
+
def _calc_batch_shape(self):
|
|
240
|
+
"""
|
|
241
|
+
Calculate the broadcast shape of the parameters used during initialization.
|
|
242
|
+
"""
|
|
243
|
+
broadcast_shape_tensor = None
|
|
244
|
+
param_dict = self.parameters.get('param_dict')
|
|
245
|
+
for value in param_dict.values():
|
|
246
|
+
if value is None:
|
|
247
|
+
return None
|
|
248
|
+
if broadcast_shape_tensor is None:
|
|
249
|
+
broadcast_shape_tensor = cast_to_tensor(value)
|
|
250
|
+
else:
|
|
251
|
+
value = cast_to_tensor(value)
|
|
252
|
+
broadcast_shape_tensor = (value + broadcast_shape_tensor)
|
|
253
|
+
return broadcast_shape_tensor.shape
|
|
254
|
+
|
|
255
|
+
def _set_prob(self):
|
|
256
|
+
"""
|
|
257
|
+
Set probability function based on the availability of `_prob` and `_log_likehood`.
|
|
258
|
+
"""
|
|
259
|
+
if hasattr(self, '_prob'):
|
|
260
|
+
self._call_prob = self._prob
|
|
261
|
+
elif hasattr(self, '_log_prob'):
|
|
262
|
+
self._call_prob = self._calc_prob_from_log_prob
|
|
263
|
+
else:
|
|
264
|
+
self._call_prob = self._raise_not_implemented_error('prob')
|
|
265
|
+
|
|
266
|
+
def _set_sd(self):
|
|
267
|
+
"""
|
|
268
|
+
Set standard deviation based on the availability of `_sd` and `_var`.
|
|
269
|
+
"""
|
|
270
|
+
if hasattr(self, '_sd'):
|
|
271
|
+
self._call_sd = self._sd
|
|
272
|
+
elif hasattr(self, '_var'):
|
|
273
|
+
self._call_sd = self._calc_sd_from_var
|
|
274
|
+
else:
|
|
275
|
+
self._call_sd = self._raise_not_implemented_error('sd')
|
|
276
|
+
|
|
277
|
+
def _set_var(self):
|
|
278
|
+
"""
|
|
279
|
+
Set variance based on the availability of `_sd` and `_var`.
|
|
280
|
+
"""
|
|
281
|
+
if hasattr(self, '_var'):
|
|
282
|
+
self._call_var = self._var
|
|
283
|
+
elif hasattr(self, '_sd'):
|
|
284
|
+
self._call_var = self._calc_var_from_sd
|
|
285
|
+
else:
|
|
286
|
+
self._call_var = self._raise_not_implemented_error('var')
|
|
287
|
+
|
|
288
|
+
def _set_log_prob(self):
|
|
289
|
+
"""
|
|
290
|
+
Set log probability based on the availability of `_prob` and `_log_prob`.
|
|
291
|
+
"""
|
|
292
|
+
if hasattr(self, '_log_prob'):
|
|
293
|
+
self._call_log_prob = self._log_prob
|
|
294
|
+
elif hasattr(self, '_prob'):
|
|
295
|
+
self._call_log_prob = self._calc_log_prob_from_prob
|
|
296
|
+
else:
|
|
297
|
+
self._call_log_prob = self._raise_not_implemented_error('log_prob')
|
|
298
|
+
|
|
299
|
+
def _set_cdf(self):
|
|
300
|
+
"""
|
|
301
|
+
Set cumulative distribution function (cdf) based on the availability of `_cdf` and `_log_cdf` and
|
|
302
|
+
`survival_functions`.
|
|
303
|
+
"""
|
|
304
|
+
if hasattr(self, '_cdf'):
|
|
305
|
+
self._call_cdf = self._cdf
|
|
306
|
+
elif hasattr(self, '_log_cdf'):
|
|
307
|
+
self._call_cdf = self._calc_cdf_from_log_cdf
|
|
308
|
+
elif hasattr(self, '_survival_function'):
|
|
309
|
+
self._call_cdf = self._calc_cdf_from_survival
|
|
310
|
+
elif hasattr(self, '_log_survival'):
|
|
311
|
+
self._call_cdf = self._calc_cdf_from_log_survival
|
|
312
|
+
else:
|
|
313
|
+
self._call_cdf = self._raise_not_implemented_error('cdf')
|
|
314
|
+
|
|
315
|
+
def _set_survival(self):
|
|
316
|
+
"""
|
|
317
|
+
Set survival function based on the availability of _survival function and `_log_survival`
|
|
318
|
+
and `_call_cdf`.
|
|
319
|
+
"""
|
|
320
|
+
if not (hasattr(self, '_survival_function') or hasattr(self, '_log_survival') or
|
|
321
|
+
hasattr(self, '_cdf') or hasattr(self, '_log_cdf')):
|
|
322
|
+
self._call_survival = self._raise_not_implemented_error(
|
|
323
|
+
'survival_function')
|
|
324
|
+
elif hasattr(self, '_survival_function'):
|
|
325
|
+
self._call_survival = self._survival_function
|
|
326
|
+
elif hasattr(self, '_log_survival'):
|
|
327
|
+
self._call_survival = self._calc_survival_from_log_survival
|
|
328
|
+
elif hasattr(self, '_call_cdf'):
|
|
329
|
+
self._call_survival = self._calc_survival_from_call_cdf
|
|
330
|
+
|
|
331
|
+
def _set_log_cdf(self):
|
|
332
|
+
"""
|
|
333
|
+
Set log cdf based on the availability of `_log_cdf` and `_call_cdf`.
|
|
334
|
+
"""
|
|
335
|
+
if not (hasattr(self, '_log_cdf') or hasattr(self, '_cdf') or
|
|
336
|
+
hasattr(self, '_survival_function') or hasattr(self, '_log_survival')):
|
|
337
|
+
self._call_log_cdf = self._raise_not_implemented_error('log_cdf')
|
|
338
|
+
elif hasattr(self, '_log_cdf'):
|
|
339
|
+
self._call_log_cdf = self._log_cdf
|
|
340
|
+
elif hasattr(self, '_call_cdf'):
|
|
341
|
+
self._call_log_cdf = self._calc_log_cdf_from_call_cdf
|
|
342
|
+
|
|
343
|
+
def _set_log_survival(self):
|
|
344
|
+
"""
|
|
345
|
+
Set log survival based on the availability of `_log_survival` and `_call_survival`.
|
|
346
|
+
"""
|
|
347
|
+
if not (hasattr(self, '_log_survival') or hasattr(self, '_survival_function') or
|
|
348
|
+
hasattr(self, '_log_cdf') or hasattr(self, '_cdf')):
|
|
349
|
+
self._call_log_survival = self._raise_not_implemented_error(
|
|
350
|
+
'log_cdf')
|
|
351
|
+
elif hasattr(self, '_log_survival'):
|
|
352
|
+
self._call_log_survival = self._log_survival
|
|
353
|
+
elif hasattr(self, '_call_survival'):
|
|
354
|
+
self._call_log_survival = self._calc_log_survival_from_call_survival
|
|
355
|
+
|
|
356
|
+
def _set_cross_entropy(self):
|
|
357
|
+
"""
|
|
358
|
+
Set log survival based on the availability of `_cross_entropy`.
|
|
359
|
+
"""
|
|
360
|
+
if hasattr(self, '_cross_entropy'):
|
|
361
|
+
self._call_cross_entropy = self._cross_entropy
|
|
362
|
+
else:
|
|
363
|
+
self._call_cross_entropy = self._raise_not_implemented_error(
|
|
364
|
+
'cross_entropy')
|
|
365
|
+
|
|
366
|
+
def _get_dist_args(self, *args, **kwargs):
    """Default hook behind `get_dist_args`; subclasses override it.

    Delegates to `raise_not_implemented_util` (presumably raises a
    not-implemented error -- confirm in `_utils.utils`).
    """
    return raise_not_implemented_util('get_dist_args', self.name, *args, **kwargs)
|
|
368
|
+
|
|
369
|
+
def get_dist_args(self, *args, **kwargs):
|
|
370
|
+
"""
|
|
371
|
+
Check the availability and validity of default parameters and `dist_spec_args`.
|
|
372
|
+
|
|
373
|
+
Args:
|
|
374
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
375
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
376
|
+
|
|
377
|
+
Note:
|
|
378
|
+
`dist_spec_args` must be passed in through list or dictionary. The order of `dist_spec_args`
|
|
379
|
+
should follow the initialization order of default parameters through `_add_parameter`.
|
|
380
|
+
If some `dist_spec_args` is None, the corresponding default parameter is returned.
|
|
381
|
+
|
|
382
|
+
Return:
|
|
383
|
+
list[Tensor], the list of parameters.
|
|
384
|
+
"""
|
|
385
|
+
return self._get_dist_args(*args, **kwargs)
|
|
386
|
+
|
|
387
|
+
def _get_dist_type(self):
    """Default hook behind `get_dist_type`; subclasses override it.

    Delegates to `raise_not_implemented_util` (presumably raises a
    not-implemented error -- confirm in `_utils.utils`).
    """
    return raise_not_implemented_util('get_dist_type', self.name)
|
|
389
|
+
|
|
390
|
+
def get_dist_type(self):
|
|
391
|
+
"""
|
|
392
|
+
Return the type of the distribution.
|
|
393
|
+
|
|
394
|
+
Return:
|
|
395
|
+
string, the name of distribution.
|
|
396
|
+
"""
|
|
397
|
+
return self._get_dist_type()
|
|
398
|
+
|
|
399
|
+
def _raise_not_implemented_error(self, func_name):
|
|
400
|
+
name = self.name
|
|
401
|
+
|
|
402
|
+
def raise_error(*args, **kwargs):
|
|
403
|
+
return raise_not_implemented_util(func_name, name, *args, **kwargs)
|
|
404
|
+
return raise_error
|
|
405
|
+
|
|
406
|
+
def log_prob(self, value, *args, **kwargs):
|
|
407
|
+
"""
|
|
408
|
+
Evaluate the log probability(pdf or pmf) at the given value.
|
|
409
|
+
|
|
410
|
+
Args:
|
|
411
|
+
value (Tensor): value to be evaluated.
|
|
412
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
413
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
414
|
+
|
|
415
|
+
Note:
|
|
416
|
+
A distribution can be optionally passed to the function by passing its `dist_spec_args` through
|
|
417
|
+
`args` or `kwargs`.
|
|
418
|
+
|
|
419
|
+
Return:
|
|
420
|
+
Tensor, the value of log probability.
|
|
421
|
+
"""
|
|
422
|
+
return self._call_log_prob(value, *args, **kwargs)
|
|
423
|
+
|
|
424
|
+
def _calc_prob_from_log_prob(self, value, *args, **kwargs):
|
|
425
|
+
r"""
|
|
426
|
+
Evaluate prob from log probability.
|
|
427
|
+
|
|
428
|
+
.. math::
|
|
429
|
+
probability(x) = \exp(log_likehood(x))
|
|
430
|
+
"""
|
|
431
|
+
return self.exp_base(self._log_prob(value, *args, **kwargs))
|
|
432
|
+
|
|
433
|
+
def prob(self, value, *args, **kwargs):
|
|
434
|
+
"""
|
|
435
|
+
Evaluate the probability (pdf or pmf) at given value. For a discrete distribution,
|
|
436
|
+
it is a probability mass function, while for a continuous distribution, it is probability density function.
|
|
437
|
+
|
|
438
|
+
Args:
|
|
439
|
+
value (Tensor): value to be evaluated.
|
|
440
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
441
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
442
|
+
|
|
443
|
+
Note:
|
|
444
|
+
A distribution can be optionally passed to the function by passing its `dist_spec_args` through
|
|
445
|
+
`args` or `kwargs`.
|
|
446
|
+
|
|
447
|
+
Return:
|
|
448
|
+
Tensor, the value of probability.
|
|
449
|
+
"""
|
|
450
|
+
return self._call_prob(value, *args, **kwargs)
|
|
451
|
+
|
|
452
|
+
def _calc_log_prob_from_prob(self, value, *args, **kwargs):
|
|
453
|
+
r"""
|
|
454
|
+
Evaluate log probability from probability.
|
|
455
|
+
|
|
456
|
+
.. math::
|
|
457
|
+
log_prob(x) = \log(prob(x))
|
|
458
|
+
"""
|
|
459
|
+
return self.log_base(self._prob(value, *args, **kwargs))
|
|
460
|
+
|
|
461
|
+
def cdf(self, value, *args, **kwargs):
|
|
462
|
+
"""
|
|
463
|
+
Evaluate the cumulative distribution function(cdf) at given value.
|
|
464
|
+
|
|
465
|
+
Args:
|
|
466
|
+
value (Tensor): value to be evaluated.
|
|
467
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
468
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
469
|
+
|
|
470
|
+
Note:
|
|
471
|
+
A distribution can be optionally passed to the function by passing its `dist_spec_args` through
|
|
472
|
+
`args` or `kwargs`.
|
|
473
|
+
|
|
474
|
+
Return:
|
|
475
|
+
Tensor, the cdf of the distribution.
|
|
476
|
+
"""
|
|
477
|
+
return self._call_cdf(value, *args, **kwargs)
|
|
478
|
+
|
|
479
|
+
def _calc_cdf_from_log_cdf(self, value, *args, **kwargs):
|
|
480
|
+
r"""
|
|
481
|
+
Evaluate cdf from log_cdf.
|
|
482
|
+
|
|
483
|
+
.. math::
|
|
484
|
+
cdf(x) = \exp(log_cdf(x))
|
|
485
|
+
"""
|
|
486
|
+
return self.exp_base(self._log_cdf(value, *args, **kwargs))
|
|
487
|
+
|
|
488
|
+
def _calc_cdf_from_survival(self, value, *args, **kwargs):
|
|
489
|
+
r"""
|
|
490
|
+
Evaluate cdf from survival function.
|
|
491
|
+
|
|
492
|
+
.. math::
|
|
493
|
+
cdf(x) = 1 - (survival_function(x))
|
|
494
|
+
"""
|
|
495
|
+
return 1.0 - self._survival_function(value, *args, **kwargs)
|
|
496
|
+
|
|
497
|
+
def _calc_cdf_from_log_survival(self, value, *args, **kwargs):
|
|
498
|
+
r"""
|
|
499
|
+
Evaluate cdf from log survival function.
|
|
500
|
+
|
|
501
|
+
.. math::
|
|
502
|
+
cdf(x) = 1 - (\exp(log_survival(x)))
|
|
503
|
+
"""
|
|
504
|
+
return 1.0 - self.exp_base(self._log_survival(value, *args, **kwargs))
|
|
505
|
+
|
|
506
|
+
def log_cdf(self, value, *args, **kwargs):
|
|
507
|
+
"""
|
|
508
|
+
Evaluate the log the cumulative distribution function(cdf) at given value.
|
|
509
|
+
|
|
510
|
+
Args:
|
|
511
|
+
value (Tensor): value to be evaluated.
|
|
512
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
513
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
514
|
+
|
|
515
|
+
Note:
|
|
516
|
+
A distribution can be optionally passed to the function by passing its `dist_spec_args` through
|
|
517
|
+
`args` or `kwargs`.
|
|
518
|
+
|
|
519
|
+
Return:
|
|
520
|
+
Tensor, the log cdf of the distribution.
|
|
521
|
+
"""
|
|
522
|
+
return self._call_log_cdf(value, *args, **kwargs)
|
|
523
|
+
|
|
524
|
+
def _calc_log_cdf_from_call_cdf(self, value, *args, **kwargs):
|
|
525
|
+
r"""
|
|
526
|
+
Evaluate log cdf from cdf.
|
|
527
|
+
|
|
528
|
+
.. math::
|
|
529
|
+
log_cdf(x) = \log(cdf(x))
|
|
530
|
+
"""
|
|
531
|
+
return self.log_base(self._call_cdf(value, *args, **kwargs))
|
|
532
|
+
|
|
533
|
+
def survival_function(self, value, *args, **kwargs):
|
|
534
|
+
"""
|
|
535
|
+
Evaluate the survival function at given value.
|
|
536
|
+
|
|
537
|
+
Args:
|
|
538
|
+
value (Tensor): value to be evaluated.
|
|
539
|
+
*args (list): the list of positional arguments forwarded to subclasses.
|
|
540
|
+
**kwargs (dict): the dictionary of keyword arguments forwarded to subclasses.
|
|
541
|
+
|
|
542
|
+
Note:
|
|
543
|
+
A distribution can be optionally passed to the function by passing its `dist_spec_args` through
|
|
544
|
+
`args` or `kwargs`.
|
|
545
|
+
|
|
546
|
+
Return:
|
|
547
|
+
Tensor, the survival function of the distribution.
|
|
548
|
+
"""
|
|
549
|
+
return self._call_survival(value, *args, **kwargs)
|
|
550
|
+
|
|
551
|
+
def _calc_survival_from_call_cdf(self, value, *args, **kwargs):
|
|
552
|
+
r"""
|
|
553
|
+
Evaluate survival function from cdf.
|
|
554
|
+
|
|
555
|
+
.. math::
|
|
556
|
+
survival_function(x) = 1 - (cdf(x))
|
|
557
|
+
"""
|
|
558
|
+
return 1.0 - self._call_cdf(value, *args, **kwargs)
|
|
559
|
+
|
|
560
|
+
def _calc_survival_from_log_survival(self, value, *args, **kwargs):
    r"""
    Evaluate the survival function from the log survival function.

    .. math::
        \mathrm{survival\_function}(x) = \exp(\mathrm{log\_survival}(x))

    Args:
        value (Tensor): value to be evaluated.
        *args (list): positional arguments forwarded to `_log_survival`.
        **kwargs (dict): keyword arguments forwarded to `_log_survival`.

    Return:
        Tensor, the survival function of the distribution.
    """
    # Uses the raw `_log_survival` hook rather than `_call_log_survival`.
    # NOTE(review): presumably this avoids a cycle when log survival is itself
    # derived from survival; confirm against how `_call_*` hooks are wired.
    return self.exp_base(self._log_survival(value, *args, **kwargs))
|
|
568
|
+
|
|
569
|
+
def log_survival(self, value, *args, **kwargs):
    """
    Evaluate the logarithm of the survival function at `value`.

    Args:
        value (Tensor): point(s) at which the log survival function is evaluated.
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The `dist_spec_args` of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the log survival function of the distribution.
    """
    log_surv = self._call_log_survival(value, *args, **kwargs)
    return log_surv
|
|
586
|
+
|
|
587
|
+
def _calc_log_survival_from_call_survival(self, value, *args, **kwargs):
    r"""
    Compute the log survival function as the logarithm of `_call_survival`.

    .. math::
        \mathrm{log\_survival}(x) = \log(\mathrm{survival\_function}(x))
    """
    survival = self._call_survival(value, *args, **kwargs)
    return self.log_base(survival)
|
|
595
|
+
|
|
596
|
+
def _kl_loss(self, *args, **kwargs):
    """Fallback hook: report that 'kl_loss' is not implemented for this distribution."""
    dist_name = self.name
    return raise_not_implemented_util('kl_loss', dist_name, *args, **kwargs)
|
|
598
|
+
|
|
599
|
+
def kl_loss(self, dist, *args, **kwargs):
    """
    Evaluate the KL divergence KL(a||b) between this distribution (a) and b.

    Args:
        dist (str): type of distribution b.
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The `dist_spec_args` of distribution b are required and must be given via `args` or `kwargs`;
        those of distribution a are optional.

    Return:
        Tensor, the kl loss function of the distribution.
    """
    divergence = self._kl_loss(dist, *args, **kwargs)
    return divergence
|
|
616
|
+
|
|
617
|
+
def _mean(self, *args, **kwargs):
    """Fallback hook: report that 'mean' is not implemented for this distribution."""
    dist_name = self.name
    return raise_not_implemented_util('mean', dist_name, *args, **kwargs)
|
|
619
|
+
|
|
620
|
+
def mean(self, *args, **kwargs):
    """
    Evaluate the mean (expected value) of the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the mean of the distribution.
    """
    expected = self._mean(*args, **kwargs)
    return expected
|
|
636
|
+
|
|
637
|
+
def _mode(self, *args, **kwargs):
    """Fallback hook: report that 'mode' is not implemented for this distribution."""
    dist_name = self.name
    return raise_not_implemented_util('mode', dist_name, *args, **kwargs)
|
|
639
|
+
|
|
640
|
+
def mode(self, *args, **kwargs):
    """
    Evaluate the mode of the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the mode of the distribution.
    """
    most_likely = self._mode(*args, **kwargs)
    return most_likely
|
|
656
|
+
|
|
657
|
+
def sd(self, *args, **kwargs):
    """
    Evaluate the standard deviation of the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the standard deviation of the distribution.
    """
    std_dev = self._call_sd(*args, **kwargs)
    return std_dev
|
|
673
|
+
|
|
674
|
+
def var(self, *args, **kwargs):
    """
    Evaluate the variance of the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the variance of the distribution.
    """
    variance = self._call_var(*args, **kwargs)
    return variance
|
|
690
|
+
|
|
691
|
+
def _calc_sd_from_var(self, *args, **kwargs):
    r"""
    Evaluate the standard deviation from the variance.

    .. math::
        \mathrm{SD}(x) = \sqrt{\mathrm{VAR}(x)}

    Args:
        *args (list): positional arguments forwarded to `_var`.
        **kwargs (dict): keyword arguments forwarded to `_var`.

    Return:
        Tensor, the standard deviation of the distribution.
    """
    # Reads the raw `_var` hook (not `_call_var`).
    # NOTE(review): presumably this avoids a cycle when variance is itself
    # derived from sd; confirm against how the `_call_*` hooks are wired.
    return self.sqrt_base(self._var(*args, **kwargs))
|
|
699
|
+
|
|
700
|
+
def _calc_var_from_sd(self, *args, **kwargs):
    r"""
    Evaluate the variance from the standard deviation.

    .. math::
        \mathrm{VAR}(x) = \mathrm{SD}(x)^2

    Args:
        *args (list): positional arguments forwarded to `_sd`.
        **kwargs (dict): keyword arguments forwarded to `_sd`.

    Return:
        Tensor, the variance of the distribution.
    """
    # Reads the raw `_sd` hook (not `_call_sd`).
    # NOTE(review): presumably this avoids a cycle when sd is itself derived
    # from variance; confirm against how the `_call_*` hooks are wired.
    return self.sq_base(self._sd(*args, **kwargs))
|
|
708
|
+
|
|
709
|
+
def _entropy(self, *args, **kwargs):
    """Fallback hook: report that 'entropy' is not implemented for this distribution."""
    dist_name = self.name
    return raise_not_implemented_util('entropy', dist_name, *args, **kwargs)
|
|
711
|
+
|
|
712
|
+
def entropy(self, *args, **kwargs):
    """
    Evaluate the entropy of the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the entropy of the distribution.
    """
    ent = self._entropy(*args, **kwargs)
    return ent
|
|
728
|
+
|
|
729
|
+
def cross_entropy(self, dist, *args, **kwargs):
    """
    Evaluate the cross entropy between this distribution (a) and distribution b.

    Args:
        dist (str): type of distribution b.
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The `dist_spec_args` of distribution b are required and must be given via `args` or `kwargs`;
        those of distribution a are optional.

    Return:
        Tensor, the cross_entropy of two distributions.
    """
    xent = self._call_cross_entropy(dist, *args, **kwargs)
    return xent
|
|
746
|
+
|
|
747
|
+
def _calc_cross_entropy(self, dist, *args, **kwargs):
    r"""
    Compute cross entropy as entropy plus KL divergence.

    .. math::
        H(X, Y) = H(X) + KL(X||Y)
    """
    # NOTE(review): the same *args/**kwargs (which, per `cross_entropy`'s
    # contract, carry distribution b's dist_spec_args) are also forwarded to
    # `_entropy`; confirm subclass `_entropy` signatures tolerate this.
    entropy_a = self._entropy(*args, **kwargs)
    divergence_ab = self._kl_loss(dist, *args, **kwargs)
    return entropy_a + divergence_ab
|
|
755
|
+
|
|
756
|
+
def _sample(self, *args, **kwargs):
    """Fallback hook: report that 'sample' is not implemented for this distribution."""
    dist_name = self.name
    return raise_not_implemented_util('sample', dist_name, *args, **kwargs)
|
|
758
|
+
|
|
759
|
+
def sample(self, *args, **kwargs):
    """
    Draw samples from the distribution.

    Args:
        *args (list): extra positional arguments, passed through to the subclass implementation.
        **kwargs (dict): extra keyword arguments, passed through to the subclass implementation.

    Note:
        The *dist_spec_args* of a distribution may optionally be supplied via `args` or `kwargs`.

    Return:
        Tensor, the sample generated from the distribution.
    """
    drawn = self._sample(*args, **kwargs)
    return drawn
|
|
775
|
+
|
|
776
|
+
def construct(self, name, *args, **kwargs):
    """
    Override `construct` in Cell.

    Dispatch to the distribution computation selected by `name`, forwarding
    `args` and `kwargs` unchanged.

    Note:
        Names of supported functions include:
        'prob', 'log_prob', 'cdf', 'log_cdf', 'survival_function', 'log_survival',
        'var', 'sd', 'mode', 'mean', 'entropy', 'kl_loss', 'cross_entropy', 'sample',
        'get_dist_args', and 'get_dist_type'.
        For 'kl_loss' and 'cross_entropy', the first positional argument is the
        type of distribution b (the `dist` argument of the public methods).
        'get_dist_type' ignores `args` and `kwargs`.

    Args:
        name (str): The name of the function.
        *args (list): A list of positional arguments that the function needs.
        **kwargs (dict): A dictionary of keyword arguments that the function needs.

    Return:
        Tensor, the value of corresponding computation method.
    """
    # NOTE(review): as a Cell `construct`, this is presumably compiled by
    # MindSpore's graph mode; keep dispatch as a flat chain of constant string
    # comparisons unless graph-mode support for alternatives is confirmed.

    # Branches call the same internal hooks as the public methods of the same
    # name, where visible in this file (e.g. 'sd' -> `_call_sd` as in `sd()`,
    # 'mean' -> `_mean` as in `mean()`), bypassing the public wrappers.
    if name == 'log_prob':
        return self._call_log_prob(*args, **kwargs)
    if name == 'prob':
        return self._call_prob(*args, **kwargs)
    if name == 'cdf':
        return self._call_cdf(*args, **kwargs)
    if name == 'log_cdf':
        return self._call_log_cdf(*args, **kwargs)
    if name == 'survival_function':
        return self._call_survival(*args, **kwargs)
    if name == 'log_survival':
        return self._call_log_survival(*args, **kwargs)
    # args[0] here is `dist`, the type of distribution b.
    if name == 'kl_loss':
        return self._kl_loss(*args, **kwargs)
    if name == 'mean':
        return self._mean(*args, **kwargs)
    if name == 'mode':
        return self._mode(*args, **kwargs)
    if name == 'sd':
        return self._call_sd(*args, **kwargs)
    if name == 'var':
        return self._call_var(*args, **kwargs)
    if name == 'entropy':
        return self._entropy(*args, **kwargs)
    # args[0] here is `dist`, the type of distribution b.
    if name == 'cross_entropy':
        return self._call_cross_entropy(*args, **kwargs)
    if name == 'sample':
        return self._sample(*args, **kwargs)
    if name == 'get_dist_args':
        return self._get_dist_args(*args, **kwargs)
    # Takes no arguments; any args/kwargs passed in are dropped.
    if name == 'get_dist_type':
        return self._get_dist_type()
    # Unrecognized name: delegate to the helper defined elsewhere in this file,
    # which presumably raises NotImplementedError; confirm.
    return raise_not_implemented_util(name, self.name, *args, **kwargs)
|