mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/nn/optim/thor.py
CHANGED
|
@@ -266,10 +266,10 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
266
266
|
\otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
|
|
267
267
|
\end{array}
|
|
268
268
|
|
|
269
|
-
:math:`a_{i-1}` represents the input of i
|
|
270
|
-
:math:`D_{s_i}` represents the derivative of the loss function of the output of the i
|
|
269
|
+
:math:`a_{i-1}` represents the input of the :math:`i`-th layer, which is the activations of the previous layer.
|
|
270
|
+
:math:`D_{s_i}` represents the derivative of the loss function of the output of the :math:`i`-th layer.
|
|
271
271
|
:math:`I` represents the identity matrix.
|
|
272
|
-
:math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the i
|
|
272
|
+
:math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the :math:`i`-th layer.
|
|
273
273
|
:math:`\otimes` represents Kronecker product, :math:`\gamma` represents 'learning rate'.
|
|
274
274
|
|
|
275
275
|
Note:
|
|
@@ -339,10 +339,10 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
339
339
|
>>> from mindspore import Tensor
|
|
340
340
|
>>>
|
|
341
341
|
>>> # Define the network structure of LeNet5. Refer to
|
|
342
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
342
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
343
343
|
>>> net = LeNet5()
|
|
344
344
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
345
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
345
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
|
|
346
346
|
>>> dataset = create_dataset()
|
|
347
347
|
>>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
|
|
348
348
|
>>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
|
|
@@ -424,7 +424,7 @@ class ThorGpu(Optimizer):
|
|
|
424
424
|
self.matmul = P.MatMul()
|
|
425
425
|
self.assign = P.Assign()
|
|
426
426
|
self.mul = P.Mul()
|
|
427
|
-
self.gather = P.
|
|
427
|
+
self.gather = P.Gather()
|
|
428
428
|
self.one = Tensor(1, mstype.int32)
|
|
429
429
|
self.feature_map = Tensor(1.0, mstype.float32)
|
|
430
430
|
self.axis = 0
|
|
@@ -653,6 +653,7 @@ class ThorGpu(Optimizer):
|
|
|
653
653
|
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
|
|
654
654
|
gradients = clip_gradient(self.enable_clip_grad, gradients)
|
|
655
655
|
lr = self.get_lr()
|
|
656
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
656
657
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
|
657
658
|
return success
|
|
658
659
|
|
|
@@ -739,7 +740,7 @@ class ThorAscend(Optimizer):
|
|
|
739
740
|
self.log = P.Log()
|
|
740
741
|
self.exp = P.Exp()
|
|
741
742
|
self.sqrt = P.Sqrt()
|
|
742
|
-
self.gather = P.
|
|
743
|
+
self.gather = P.Gather()
|
|
743
744
|
self.assign = P.Assign()
|
|
744
745
|
self.cast = P.Cast()
|
|
745
746
|
self.eye = P.Eye()
|
|
@@ -1304,5 +1305,6 @@ class ThorAscend(Optimizer):
|
|
|
1304
1305
|
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
|
|
1305
1306
|
gradients = clip_gradient(self.enable_clip_grad, gradients)
|
|
1306
1307
|
lr = self.get_lr()
|
|
1308
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
1307
1309
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
|
1308
1310
|
return success
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from mindspore import context
|
|
17
17
|
from mindspore.nn.cell import Cell
|
|
18
18
|
from mindspore.ops import operations as P
|
|
19
|
+
from mindspore.ops import functional as F
|
|
19
20
|
from mindspore.ops.operations import _inner_ops as inner
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
21
22
|
from mindspore.common.tensor import Tensor
|
|
@@ -96,7 +97,6 @@ class Bijector(Cell):
|
|
|
96
97
|
self.cast_base = P.Cast()
|
|
97
98
|
self.dtype_base = P.DType()
|
|
98
99
|
self.shape_base = P.Shape()
|
|
99
|
-
self.fill_base = P.Fill()
|
|
100
100
|
self.sametypeshape_base = inner.SameTypeShape()
|
|
101
101
|
self.issubclass_base = inner.IsSubClass()
|
|
102
102
|
|
|
@@ -140,13 +140,13 @@ class Bijector(Cell):
|
|
|
140
140
|
if self.issubclass_base(value_type, mstype.float_):
|
|
141
141
|
return value
|
|
142
142
|
return raise_type_error('input value of bijector', value_type, mstype.float_)
|
|
143
|
-
dtype_tensor =
|
|
143
|
+
dtype_tensor = F.fill(self.dtype, self.shape_base(value), 0.0)
|
|
144
144
|
self.sametypeshape_base(value, dtype_tensor)
|
|
145
145
|
return value
|
|
146
146
|
|
|
147
147
|
def _shape_mapping(self, shape):
|
|
148
|
-
shape_tensor =
|
|
149
|
-
dist_shape_tensor =
|
|
148
|
+
shape_tensor = F.fill(self.parameter_type, shape, 0.0)
|
|
149
|
+
dist_shape_tensor = F.fill(
|
|
150
150
|
self.parameter_type, self.batch_shape, 0.0)
|
|
151
151
|
return (shape_tensor + dist_shape_tensor).shape
|
|
152
152
|
|
|
@@ -165,7 +165,7 @@ class Bijector(Cell):
|
|
|
165
165
|
self.common_dtype = None
|
|
166
166
|
# cast value to a tensor if it is not None
|
|
167
167
|
if isinstance(value, bool) or value is None:
|
|
168
|
-
raise TypeError("{} cannot be type {
|
|
168
|
+
raise TypeError(f"{name} cannot be type {type(value)}")
|
|
169
169
|
value_t = Tensor(value)
|
|
170
170
|
# if the bijector's dtype is not specified
|
|
171
171
|
if self.dtype is None:
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""PowerTransform Bijector"""
|
|
16
16
|
from mindspore.ops import operations as P
|
|
17
|
+
from mindspore.ops import functional as F
|
|
17
18
|
from ..distribution._utils.utils import check_greater_equal_zero
|
|
18
19
|
from ..distribution._utils.custom_ops import exp_generic, log_generic
|
|
19
20
|
from .bijector import Bijector
|
|
@@ -68,10 +69,7 @@ class PowerTransform(Bijector):
|
|
|
68
69
|
>>> print(ans4.shape)
|
|
69
70
|
(3,)
|
|
70
71
|
"""
|
|
71
|
-
|
|
72
|
-
def __init__(self,
|
|
73
|
-
power=0.,
|
|
74
|
-
name='PowerTransform'):
|
|
72
|
+
def __init__(self, power=0., name='PowerTransform'):
|
|
75
73
|
param = dict(locals())
|
|
76
74
|
param['param_dict'] = {'power': power}
|
|
77
75
|
super(PowerTransform, self).__init__(name=name, param=param)
|
|
@@ -84,7 +82,6 @@ class PowerTransform(Bijector):
|
|
|
84
82
|
self.equal_base = P.Equal()
|
|
85
83
|
self.exp = exp_generic
|
|
86
84
|
self.expm1 = P.Expm1()
|
|
87
|
-
self.fill = P.Fill()
|
|
88
85
|
self.log = log_generic
|
|
89
86
|
self.log1p = P.Log1p()
|
|
90
87
|
self.select_base = P.Select()
|
|
@@ -116,17 +113,18 @@ class PowerTransform(Bijector):
|
|
|
116
113
|
power_local = self.cast_param_by_value(x, self.power)
|
|
117
114
|
|
|
118
115
|
# broadcast the values of x and power
|
|
119
|
-
ones =
|
|
120
|
-
|
|
116
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
|
|
117
|
+
1.)
|
|
121
118
|
power_local = power_local * ones
|
|
122
119
|
x = x * ones
|
|
123
|
-
safe_power = self.select_base(
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
forward_v = self.select_base(
|
|
128
|
-
|
|
129
|
-
|
|
120
|
+
safe_power = self.select_base(
|
|
121
|
+
self.equal_base(power_local,
|
|
122
|
+
P.ZerosLike()(power_local)), ones, power_local)
|
|
123
|
+
|
|
124
|
+
forward_v = self.select_base(
|
|
125
|
+
self.equal_base(power_local,
|
|
126
|
+
P.ZerosLike()(power_local)), self.exp(x),
|
|
127
|
+
self.exp(self.log1p(x * safe_power) / safe_power))
|
|
130
128
|
return forward_v
|
|
131
129
|
|
|
132
130
|
def _inverse(self, y):
|
|
@@ -137,17 +135,18 @@ class PowerTransform(Bijector):
|
|
|
137
135
|
power_local = self.cast_param_by_value(y, self.power)
|
|
138
136
|
|
|
139
137
|
# broadcast the values of y and power
|
|
140
|
-
ones =
|
|
141
|
-
|
|
138
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(y + power_local),
|
|
139
|
+
1.)
|
|
142
140
|
power_local = power_local * ones
|
|
143
141
|
y = y * ones
|
|
144
|
-
safe_power = self.select_base(
|
|
145
|
-
|
|
146
|
-
|
|
142
|
+
safe_power = self.select_base(
|
|
143
|
+
self.equal_base(power_local,
|
|
144
|
+
P.ZerosLike()(power_local)), ones, power_local)
|
|
147
145
|
|
|
148
|
-
inverse_v = self.select_base(
|
|
149
|
-
|
|
150
|
-
|
|
146
|
+
inverse_v = self.select_base(
|
|
147
|
+
self.equal_base(power_local,
|
|
148
|
+
P.ZerosLike()(power_local)), self.log(y),
|
|
149
|
+
self.expm1(self.log(y) * safe_power) / safe_power)
|
|
151
150
|
|
|
152
151
|
return inverse_v
|
|
153
152
|
|
|
@@ -167,14 +166,15 @@ class PowerTransform(Bijector):
|
|
|
167
166
|
power_local = self.cast_param_by_value(x, self.power)
|
|
168
167
|
|
|
169
168
|
# broadcast the values of x and power
|
|
170
|
-
ones =
|
|
171
|
-
|
|
169
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
|
|
170
|
+
1.)
|
|
172
171
|
power_local = power_local * ones
|
|
173
172
|
x = x * ones
|
|
174
173
|
|
|
175
|
-
forward_log_j = self.select_base(
|
|
176
|
-
|
|
177
|
-
|
|
174
|
+
forward_log_j = self.select_base(
|
|
175
|
+
self.equal_base(power_local,
|
|
176
|
+
P.ZerosLike()(power_local)), x,
|
|
177
|
+
(1. / power_local - 1) * self.log1p(x * power_local))
|
|
178
178
|
|
|
179
179
|
return forward_log_j
|
|
180
180
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Softplus Bijector"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.nn.layer.activation import LogSigmoid
|
|
19
20
|
from ..distribution._utils.custom_ops import exp_generic, log_generic
|
|
20
21
|
from .bijector import Bijector
|
|
@@ -84,7 +85,6 @@ class Softplus(Bijector):
|
|
|
84
85
|
self.abs = P.Abs()
|
|
85
86
|
self.dtypeop = P.DType()
|
|
86
87
|
self.cast = P.Cast()
|
|
87
|
-
self.fill = P.Fill()
|
|
88
88
|
self.greater = P.Greater()
|
|
89
89
|
self.less = P.Less()
|
|
90
90
|
self.log_sigmoid = LogSigmoid()
|
|
@@ -103,7 +103,7 @@ class Softplus(Bijector):
|
|
|
103
103
|
too_large = self.greater(x, -self.threshold)
|
|
104
104
|
too_small_value = self.exp(x)
|
|
105
105
|
too_large_value = x
|
|
106
|
-
ones =
|
|
106
|
+
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
|
|
107
107
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
|
108
108
|
x = self.select(too_small_or_too_large, ones, x)
|
|
109
109
|
y = self.log(self.exp(x) + 1.0)
|
|
@@ -119,7 +119,7 @@ class Softplus(Bijector):
|
|
|
119
119
|
too_large = self.greater(x, (-1) * self.threshold)
|
|
120
120
|
too_small_value = self.log(x)
|
|
121
121
|
too_large_value = x
|
|
122
|
-
ones =
|
|
122
|
+
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
|
|
123
123
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
|
124
124
|
x = self.select(too_small_or_too_large, ones, x)
|
|
125
125
|
y = x + self.log(self.abs(self.expm1((-1)*x)))
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Utility functions to help distribution class."""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
19
20
|
from mindspore.ops.primitive import constexpr
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
@@ -52,7 +53,6 @@ def log_generic(input_x):
|
|
|
52
53
|
log = P.Log()
|
|
53
54
|
less = P.Less()
|
|
54
55
|
lessequal = P.LessEqual()
|
|
55
|
-
fill = P.Fill()
|
|
56
56
|
cast = P.Cast()
|
|
57
57
|
dtype = P.DType()
|
|
58
58
|
shape = P.Shape()
|
|
@@ -61,8 +61,8 @@ def log_generic(input_x):
|
|
|
61
61
|
|
|
62
62
|
if not checktype(dtype(input_x), mstype.float_):
|
|
63
63
|
input_x = cast(input_x, mstype.float32)
|
|
64
|
-
nan = fill(dtype(input_x), shape(input_x), np.nan)
|
|
65
|
-
inf = fill(dtype(input_x), shape(input_x), np.inf)
|
|
64
|
+
nan = F.fill(dtype(input_x), shape(input_x), np.nan)
|
|
65
|
+
inf = F.fill(dtype(input_x), shape(input_x), np.inf)
|
|
66
66
|
neg_x = less(input_x, 0.0)
|
|
67
67
|
nonpos_x = lessequal(input_x, 0.0)
|
|
68
68
|
log_x = log(input_x)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Bernoulli Distribution"""
|
|
16
16
|
from mindspore.common import dtype as mstype
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
from mindspore import _checkparam as Validator
|
|
20
21
|
from .distribution import Distribution
|
|
@@ -151,7 +152,6 @@ class Bernoulli(Distribution):
|
|
|
151
152
|
self.cast = P.Cast()
|
|
152
153
|
self.const = P.ScalarToTensor()
|
|
153
154
|
self.floor = P.Floor()
|
|
154
|
-
self.fill = P.Fill()
|
|
155
155
|
self.less = P.Less()
|
|
156
156
|
self.shape = P.Shape()
|
|
157
157
|
self.select = P.Select()
|
|
@@ -200,8 +200,8 @@ class Bernoulli(Distribution):
|
|
|
200
200
|
MODE(B) = 1 if probs1 > 0.5 else 0
|
|
201
201
|
"""
|
|
202
202
|
probs1 = self._check_param_type(probs1)
|
|
203
|
-
zeros =
|
|
204
|
-
ones =
|
|
203
|
+
zeros = F.fill(self.dtype, self.shape(probs1), 0.0)
|
|
204
|
+
ones = F.fill(self.dtype, self.shape(probs1), 1.0)
|
|
205
205
|
comp = self.less(0.5, probs1)
|
|
206
206
|
return self.select(comp, ones, zeros)
|
|
207
207
|
|
|
@@ -278,9 +278,9 @@ class Bernoulli(Distribution):
|
|
|
278
278
|
probs0 = self.broadcast((1.0 - probs1), broadcast_shape_tensor)
|
|
279
279
|
comp_zero = self.less(value, 0.0)
|
|
280
280
|
comp_one = self.less(value, 1.0)
|
|
281
|
-
zeros =
|
|
281
|
+
zeros = F.fill(self.parameter_type, self.shape(
|
|
282
282
|
broadcast_shape_tensor), 0.0)
|
|
283
|
-
ones =
|
|
283
|
+
ones = F.fill(self.parameter_type, self.shape(
|
|
284
284
|
broadcast_shape_tensor), 1.0)
|
|
285
285
|
less_than_zero = self.select(comp_zero, zeros, probs0)
|
|
286
286
|
return self.select(comp_one, less_than_zero, ones)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Beta Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
import mindspore.nn as nn
|
|
20
21
|
from mindspore import _checkparam as Validator
|
|
@@ -186,7 +187,6 @@ class Beta(Distribution):
|
|
|
186
187
|
self.pow = P.Pow()
|
|
187
188
|
self.squeeze = P.Squeeze(0)
|
|
188
189
|
self.cast = P.Cast()
|
|
189
|
-
self.fill = P.Fill()
|
|
190
190
|
self.shape = P.Shape()
|
|
191
191
|
self.select = P.Select()
|
|
192
192
|
self.logicaland = P.LogicalAnd()
|
|
@@ -266,7 +266,7 @@ class Beta(Distribution):
|
|
|
266
266
|
comp2 = self.greater(concentration0, 1.)
|
|
267
267
|
cond = self.logicaland(comp1, comp2)
|
|
268
268
|
batch_shape = self.shape(concentration1 + concentration0)
|
|
269
|
-
nan =
|
|
269
|
+
nan = F.fill(self.dtype, batch_shape, np.nan)
|
|
270
270
|
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
|
|
271
271
|
return self.select(cond, mode, nan)
|
|
272
272
|
|
|
@@ -379,7 +379,7 @@ class Beta(Distribution):
|
|
|
379
379
|
sample_shape = (1,)
|
|
380
380
|
else:
|
|
381
381
|
sample_shape = origin_shape
|
|
382
|
-
ones =
|
|
382
|
+
ones = F.fill(self.dtype, sample_shape, 1.0)
|
|
383
383
|
sample_gamma1 = C.gamma(
|
|
384
384
|
sample_shape, alpha=concentration1, beta=ones, seed=self.seed)
|
|
385
385
|
sample_gamma2 = C.gamma(
|
|
@@ -17,6 +17,7 @@ import numpy as np
|
|
|
17
17
|
from mindspore import context
|
|
18
18
|
from mindspore.common import Tensor
|
|
19
19
|
from mindspore.ops import operations as P
|
|
20
|
+
from mindspore.ops import functional as F
|
|
20
21
|
from mindspore.ops import composite as C
|
|
21
22
|
from mindspore.ops.functional import stop_gradient
|
|
22
23
|
from mindspore.ops.operations import _inner_ops as inner
|
|
@@ -149,7 +150,6 @@ class Categorical(Distribution):
|
|
|
149
150
|
self.dtypeop = P.DType()
|
|
150
151
|
self.exp = exp_generic
|
|
151
152
|
self.expand_dim = P.ExpandDims()
|
|
152
|
-
self.fill = P.Fill()
|
|
153
153
|
self.gather = P.GatherNd()
|
|
154
154
|
self.greater = P.Greater()
|
|
155
155
|
self.issubclass = inner.IsSubClass()
|
|
@@ -292,7 +292,7 @@ class Categorical(Distribution):
|
|
|
292
292
|
# here we simulate casting to int but still keeping float dtype
|
|
293
293
|
value = self.cast(value, self.dtypeop(probs))
|
|
294
294
|
|
|
295
|
-
zeros =
|
|
295
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
296
296
|
between_zero_neone = self.logicand(self.less(value, 0,),
|
|
297
297
|
self.greater(value, -1.))
|
|
298
298
|
value = self.select(between_zero_neone,
|
|
@@ -338,8 +338,8 @@ class Categorical(Distribution):
|
|
|
338
338
|
# reshape into label shape N
|
|
339
339
|
logits_pmf = self.gather(self.reshape(
|
|
340
340
|
logits, (-1, num_classes)), index)
|
|
341
|
-
nan =
|
|
342
|
-
|
|
341
|
+
nan = F.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf),
|
|
342
|
+
self.nan)
|
|
343
343
|
logits_pmf = self.select(out_of_bound, nan, logits_pmf)
|
|
344
344
|
ans = self.reshape(logits_pmf, label_shape)
|
|
345
345
|
if drop_dim:
|
|
@@ -359,7 +359,7 @@ class Categorical(Distribution):
|
|
|
359
359
|
|
|
360
360
|
value = self.cast(value, self.dtypeop(probs))
|
|
361
361
|
|
|
362
|
-
zeros =
|
|
362
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
363
363
|
between_zero_neone = self.logicand(
|
|
364
364
|
self.less(value, 0,), self.greater(value, -1.))
|
|
365
365
|
value = self.select(between_zero_neone, zeros, P.Floor()(value))
|
|
@@ -394,7 +394,7 @@ class Categorical(Distribution):
|
|
|
394
394
|
# reshape probs and fill less_than_zero places with 0
|
|
395
395
|
probs = self.reshape(probs, (-1, num_classes))
|
|
396
396
|
cdf = self.gather(self.cumsum(probs, 1), index)
|
|
397
|
-
zeros =
|
|
397
|
+
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
|
398
398
|
cdf = self.select(less_than_zero, zeros, cdf)
|
|
399
399
|
cdf = self.reshape(cdf, label_shape)
|
|
400
400
|
|
|
@@ -425,7 +425,7 @@ class Categorical(Distribution):
|
|
|
425
425
|
sample_shape = (1,)
|
|
426
426
|
|
|
427
427
|
probs_2d = self.reshape(probs, (-1, num_classes))
|
|
428
|
-
sample_tensor =
|
|
428
|
+
sample_tensor = F.fill(self.dtype, shape, 1.0)
|
|
429
429
|
sample_tensor = self.reshape(sample_tensor, (-1, 1))
|
|
430
430
|
num_sample = self.shape(sample_tensor)[0]
|
|
431
431
|
samples = C.multinomial(probs_2d, num_sample, seed=self.seed)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""basic"""
|
|
16
16
|
from mindspore import context
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.nn.cell import Cell
|
|
19
20
|
from mindspore.ops.primitive import constexpr
|
|
20
21
|
from mindspore.ops.operations import _inner_ops as inner
|
|
@@ -113,7 +114,6 @@ class Distribution(Cell):
|
|
|
113
114
|
# ops needed for the base class
|
|
114
115
|
self.cast_base = P.Cast()
|
|
115
116
|
self.dtype_base = P.DType()
|
|
116
|
-
self.fill_base = P.Fill()
|
|
117
117
|
self.sametypeshape_base = inner.SameTypeShape()
|
|
118
118
|
self.sq_base = P.Square()
|
|
119
119
|
self.sqrt_base = P.Sqrt()
|
|
@@ -194,11 +194,11 @@ class Distribution(Cell):
|
|
|
194
194
|
if broadcast_shape is None:
|
|
195
195
|
broadcast_shape = self.shape_base(arg)
|
|
196
196
|
common_dtype = self.dtype_base(arg)
|
|
197
|
-
broadcast_shape_tensor =
|
|
197
|
+
broadcast_shape_tensor = F.fill(
|
|
198
198
|
common_dtype, broadcast_shape, 1.0)
|
|
199
199
|
else:
|
|
200
200
|
broadcast_shape = self.shape_base(arg + broadcast_shape_tensor)
|
|
201
|
-
broadcast_shape_tensor =
|
|
201
|
+
broadcast_shape_tensor = F.fill(
|
|
202
202
|
common_dtype, broadcast_shape, 1.0)
|
|
203
203
|
arg = self.broadcast(arg, broadcast_shape_tensor)
|
|
204
204
|
# check if the arguments have the same dtype
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Gamma Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
import mindspore.nn as nn
|
|
20
21
|
from mindspore import _checkparam as Validator
|
|
@@ -185,7 +186,6 @@ class Gamma(Distribution):
|
|
|
185
186
|
self.squeeze = P.Squeeze(0)
|
|
186
187
|
self.cast = P.Cast()
|
|
187
188
|
self.dtypeop = P.DType()
|
|
188
|
-
self.fill = P.Fill()
|
|
189
189
|
self.shape = P.Shape()
|
|
190
190
|
self.select = P.Select()
|
|
191
191
|
self.greater = P.Greater()
|
|
@@ -265,8 +265,8 @@ class Gamma(Distribution):
|
|
|
265
265
|
"""
|
|
266
266
|
concentration, rate = self._check_param_type(concentration, rate)
|
|
267
267
|
mode = (concentration - 1.) / rate
|
|
268
|
-
nan =
|
|
269
|
-
|
|
268
|
+
nan = F.fill(self.dtypeop(concentration), self.shape(concentration),
|
|
269
|
+
np.nan)
|
|
270
270
|
comp = self.greater(concentration, 1.)
|
|
271
271
|
return self.select(comp, mode, nan)
|
|
272
272
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Geometric Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
19
20
|
from mindspore.ops import composite as C
|
|
20
21
|
from mindspore import _checkparam as Validator
|
|
@@ -160,7 +161,6 @@ class Geometric(Distribution):
|
|
|
160
161
|
self.cast = P.Cast()
|
|
161
162
|
self.const = P.ScalarToTensor()
|
|
162
163
|
self.dtypeop = P.DType()
|
|
163
|
-
self.fill = P.Fill()
|
|
164
164
|
self.floor = P.Floor()
|
|
165
165
|
self.issubclass = inner.IsSubClass()
|
|
166
166
|
self.less = P.Less()
|
|
@@ -212,7 +212,7 @@ class Geometric(Distribution):
|
|
|
212
212
|
MODE(Geo) = 0
|
|
213
213
|
"""
|
|
214
214
|
probs1 = self._check_param_type(probs1)
|
|
215
|
-
return
|
|
215
|
+
return F.fill(self.dtype, self.shape(probs1), 0.)
|
|
216
216
|
|
|
217
217
|
def _var(self, probs1=None):
|
|
218
218
|
r"""
|
|
@@ -260,7 +260,7 @@ class Geometric(Distribution):
|
|
|
260
260
|
value = self.floor(value)
|
|
261
261
|
probs1 = self._check_param_type(probs1)
|
|
262
262
|
pmf = self.exp(self.log(1.0 - probs1) * value + self.log(probs1))
|
|
263
|
-
zeros =
|
|
263
|
+
zeros = F.fill(self.dtypeop(pmf), self.shape(pmf), 0.0)
|
|
264
264
|
comp = self.less(value, zeros)
|
|
265
265
|
return self.select(comp, zeros, pmf)
|
|
266
266
|
|
|
@@ -283,7 +283,7 @@ class Geometric(Distribution):
|
|
|
283
283
|
probs1 = self._check_param_type(probs1)
|
|
284
284
|
probs0 = 1.0 - probs1
|
|
285
285
|
cdf = 1.0 - self.pow(probs0, value + 1.0)
|
|
286
|
-
zeros =
|
|
286
|
+
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
|
287
287
|
comp = self.less(value, zeros)
|
|
288
288
|
return self.select(comp, zeros, cdf)
|
|
289
289
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Gumbel Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore import _checkparam as Validator
|
|
19
20
|
from mindspore.common import dtype as mstype
|
|
20
21
|
import mindspore.nn.probability.bijector as msb
|
|
@@ -101,7 +102,6 @@ class Gumbel(TransformedDistribution):
|
|
|
101
102
|
self.const = P.ScalarToTensor()
|
|
102
103
|
self.exp = exp_generic
|
|
103
104
|
self.expm1 = P.Expm1()
|
|
104
|
-
self.fill = P.Fill()
|
|
105
105
|
self.lgamma = P.Lgamma()
|
|
106
106
|
self.log = log_generic
|
|
107
107
|
self.shape = P.Shape()
|
|
@@ -163,7 +163,7 @@ class Gumbel(TransformedDistribution):
|
|
|
163
163
|
"""
|
|
164
164
|
The mode of the distribution.
|
|
165
165
|
"""
|
|
166
|
-
return self.loc *
|
|
166
|
+
return self.loc * F.fill(self.parameter_type, self.shape(self.scale), 1.0)
|
|
167
167
|
|
|
168
168
|
def _sd(self):
|
|
169
169
|
r"""
|
|
@@ -173,7 +173,7 @@ class Gumbel(TransformedDistribution):
|
|
|
173
173
|
STD(X) = \frac{\pi}{\sqrt{6}} * scale
|
|
174
174
|
"""
|
|
175
175
|
scale = self.scale * \
|
|
176
|
-
|
|
176
|
+
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
|
|
177
177
|
return scale * np.pi / self.sqrt(self.const(6., mstype.float32))
|
|
178
178
|
|
|
179
179
|
def _entropy(self):
|
|
@@ -184,7 +184,7 @@ class Gumbel(TransformedDistribution):
|
|
|
184
184
|
H(X) = 1. + \log(scale) + Euler-Mascheroni_constant
|
|
185
185
|
"""
|
|
186
186
|
scale = self.scale * \
|
|
187
|
-
|
|
187
|
+
F.fill(self.parameter_type, self.broadcast_shape, 1.0)
|
|
188
188
|
return 1. + self.log(scale) + np.euler_gamma
|
|
189
189
|
|
|
190
190
|
def _log_prob(self, value):
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""LogNormal Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.common import dtype as mstype
|
|
19
20
|
import mindspore.nn.probability.bijector as msb
|
|
20
21
|
import mindspore.nn.probability.distribution as msd
|
|
@@ -101,7 +102,6 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
101
102
|
self.expm1 = P.Expm1()
|
|
102
103
|
self.log = log_generic
|
|
103
104
|
self.erf = P.Erf()
|
|
104
|
-
self.fill = P.Fill()
|
|
105
105
|
self.greater = P.Greater()
|
|
106
106
|
self.select = P.Select()
|
|
107
107
|
self.shape = P.Shape()
|
|
@@ -202,7 +202,7 @@ class LogNormal(msd.TransformedDistribution):
|
|
|
202
202
|
cdf = self.distribution("cdf", inverse_value, mean, sd)
|
|
203
203
|
|
|
204
204
|
# to increase numerical stability, set cdf = 0 when value <= 0
|
|
205
|
-
zeros =
|
|
205
|
+
zeros = F.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
|
|
206
206
|
|
|
207
207
|
return self.select(self.greater(value, 0.), cdf, zeros)
|
|
208
208
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Logistic Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
from mindspore import _checkparam as Validator
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
@@ -153,7 +154,6 @@ class Logistic(Distribution):
|
|
|
153
154
|
self.dtypeop = P.DType()
|
|
154
155
|
self.exp = exp_generic
|
|
155
156
|
self.expm1 = P.Expm1()
|
|
156
|
-
self.fill = P.Fill()
|
|
157
157
|
self.less = P.Less()
|
|
158
158
|
self.log = log_generic
|
|
159
159
|
self.log1p = P.Log1p()
|
|
@@ -179,7 +179,7 @@ class Logistic(Distribution):
|
|
|
179
179
|
too_small_value = self.exp(x)
|
|
180
180
|
too_large_value = x
|
|
181
181
|
too_small_or_too_large = self.logicalor(too_small, too_large)
|
|
182
|
-
ones =
|
|
182
|
+
ones = F.fill(self.dtypeop(x), self.shape(x), 1.0)
|
|
183
183
|
x = self.select(too_small_or_too_large, ones, x)
|
|
184
184
|
y = self.log(self.exp(x) + 1.0)
|
|
185
185
|
return self.select(too_small, too_small_value, self.select(too_large, too_large_value, y))
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Poisson Distribution"""
|
|
16
16
|
import numpy as np
|
|
17
17
|
from mindspore.ops import operations as P
|
|
18
|
+
from mindspore.ops import functional as F
|
|
18
19
|
from mindspore.ops import composite as C
|
|
19
20
|
from mindspore import _checkparam as Validator
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
@@ -149,7 +150,6 @@ class Poisson(Distribution):
|
|
|
149
150
|
self.floor = P.Floor()
|
|
150
151
|
self.dtypeop = P.DType()
|
|
151
152
|
self.shape = P.Shape()
|
|
152
|
-
self.fill = P.Fill()
|
|
153
153
|
self.less = P.Less()
|
|
154
154
|
self.equal = P.Equal()
|
|
155
155
|
self.select = P.Select()
|
|
@@ -228,8 +228,8 @@ class Poisson(Distribution):
|
|
|
228
228
|
value = self.cast(value, self.dtype)
|
|
229
229
|
rate = self._check_param_type(rate)
|
|
230
230
|
log_rate = self.log(rate)
|
|
231
|
-
zeros =
|
|
232
|
-
inf =
|
|
231
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
232
|
+
inf = F.fill(self.dtypeop(value), self.shape(value), np.inf)
|
|
233
233
|
safe_x = self.select(self.less(value, zeros), zeros, value)
|
|
234
234
|
y = log_rate * safe_x - self.lgamma(safe_x + 1.)
|
|
235
235
|
comp = self.equal(value, safe_x)
|
|
@@ -254,7 +254,7 @@ class Poisson(Distribution):
|
|
|
254
254
|
value = self._check_value(value, 'value')
|
|
255
255
|
value = self.cast(value, self.dtype)
|
|
256
256
|
rate = self._check_param_type(rate)
|
|
257
|
-
zeros =
|
|
257
|
+
zeros = F.fill(self.dtypeop(value), self.shape(value), 0.0)
|
|
258
258
|
comp = self.less(value, zeros)
|
|
259
259
|
safe_x = self.select(comp, zeros, value)
|
|
260
260
|
cdf = 1. - self.igamma(1. + safe_x, rate)
|