mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the package registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/optim/rprop.py
CHANGED
|
@@ -83,7 +83,7 @@ class Rprop(Optimizer):
|
|
|
83
83
|
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
|
|
84
84
|
one group of `params`.
|
|
85
85
|
|
|
86
|
-
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: 0.1.
|
|
86
|
+
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Learning_rate. Default: ``0.1`` .
|
|
87
87
|
|
|
88
88
|
- float: The fixed learning rate value. Must be equal to or greater than 0.
|
|
89
89
|
|
|
@@ -98,10 +98,10 @@ class Rprop(Optimizer):
|
|
|
98
98
|
LearningRateSchedule with step as the input to get the learning rate of current step.
|
|
99
99
|
|
|
100
100
|
etas (tuple[float, float]): The factor of multiplicative increasing or
|
|
101
|
-
descreasing(etaminus, etaplus). Default: (0.5, 1.2).
|
|
101
|
+
descreasing(etaminus, etaplus). Default: ``(0.5, 1.2)`` .
|
|
102
102
|
step_sizes(tuple[float, float]): The allowed minimal and maximal step size(min_step_sizes, max_step_size).
|
|
103
|
-
Default: (1e-6, 50.).
|
|
104
|
-
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: 0.0.
|
|
103
|
+
Default: ``(1e-6, 50.)`` .
|
|
104
|
+
weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
|
|
105
105
|
|
|
106
106
|
- float: The fixed weight decay value. Must be equal to or greater than 0.
|
|
107
107
|
|
|
@@ -134,7 +134,9 @@ class Rprop(Optimizer):
|
|
|
134
134
|
>>> import mindspore as ms
|
|
135
135
|
>>> from mindspore import nn
|
|
136
136
|
>>>
|
|
137
|
-
>>>
|
|
137
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
138
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
139
|
+
>>> net = LeNet5()
|
|
138
140
|
>>> #1) All parameters use the same learning rate and weight decay
|
|
139
141
|
>>> optim = nn.Rprop(params=net.trainable_params())
|
|
140
142
|
>>>
|
|
@@ -152,7 +154,7 @@ class Rprop(Optimizer):
|
|
|
152
154
|
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
|
153
155
|
>>>
|
|
154
156
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
155
|
-
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
|
|
157
|
+
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
|
|
156
158
|
"""
|
|
157
159
|
|
|
158
160
|
@opt_init_args_register
|
|
@@ -187,8 +189,8 @@ class Rprop(Optimizer):
|
|
|
187
189
|
self.prev = self._parameters.clone(prefix="prev", init='zeros')
|
|
188
190
|
self.step_size = self._parameters.clone(prefix="step_size", init='zeros')
|
|
189
191
|
|
|
190
|
-
self.fill = P.Fill()
|
|
191
192
|
self.sign = P.Sign()
|
|
193
|
+
self.fill = P.FillV2()
|
|
192
194
|
self.assign = P.Assign()
|
|
193
195
|
self.assignadd = P.AssignAdd()
|
|
194
196
|
self.cast = P.Cast()
|
|
@@ -202,8 +204,7 @@ class Rprop(Optimizer):
|
|
|
202
204
|
gradients = self.gradients_centralization(gradients)
|
|
203
205
|
gradients = self.scale_grad(gradients)
|
|
204
206
|
lrs = self.get_lr()
|
|
205
|
-
|
|
206
|
-
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
207
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
207
208
|
success = True
|
|
208
209
|
|
|
209
210
|
for index, (grad, param, prev, step_size) in enumerate(zip(gradients, self._parameters,
|
|
@@ -219,14 +220,26 @@ class Rprop(Optimizer):
|
|
|
219
220
|
param_fp32 = self.cast(param, mstype.float32)
|
|
220
221
|
|
|
221
222
|
sign = self.sign(gradient_fp32 * prev)
|
|
222
|
-
sign = self.select(
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
223
|
+
sign = self.select(
|
|
224
|
+
sign > 0,
|
|
225
|
+
self.fill(sign.shape, self.cast(self.etaplus, mstype.float32)),
|
|
226
|
+
sign)
|
|
227
|
+
sign = self.select(
|
|
228
|
+
sign < 0,
|
|
229
|
+
self.fill(sign.shape, self.cast(self.etaminus,
|
|
230
|
+
mstype.float32)), sign)
|
|
231
|
+
sign = self.select(
|
|
232
|
+
sign == 0, self.fill(sign.shape,
|
|
233
|
+
self.cast(1., mstype.float32)), sign)
|
|
234
|
+
|
|
235
|
+
step_size_fp32 = ops.clip_by_value(step_size_fp32 * sign,
|
|
236
|
+
self.step_size_min,
|
|
237
|
+
self.step_size_max)
|
|
238
|
+
|
|
239
|
+
gradient_update = self.select(
|
|
240
|
+
sign == self.etaminus,
|
|
241
|
+
self.fill(sign.shape, self.cast(0., mstype.float32)),
|
|
242
|
+
gradient_fp32)
|
|
230
243
|
next_param = param_fp32 - self.sign(gradient_update) * step_size_fp32
|
|
231
244
|
|
|
232
245
|
self.assign(param, self.cast(next_param, param.dtype))
|
mindspore/nn/optim/sgd.py
CHANGED
|
@@ -44,17 +44,17 @@ class SGD(Optimizer):
|
|
|
44
44
|
momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
|
|
45
45
|
|
|
46
46
|
.. math::
|
|
47
|
-
|
|
47
|
+
v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
|
|
48
48
|
|
|
49
49
|
If nesterov is True:
|
|
50
50
|
|
|
51
51
|
.. math::
|
|
52
|
-
|
|
52
|
+
p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
|
|
53
53
|
|
|
54
54
|
If nesterov is False:
|
|
55
55
|
|
|
56
56
|
.. math::
|
|
57
|
-
|
|
57
|
+
p_{t+1} = p_{t} - lr \ast v_{t+1}
|
|
58
58
|
|
|
59
59
|
To be noticed, for the first step, :math:`v_{t+1} = gradient`.
|
|
60
60
|
|
|
@@ -90,7 +90,7 @@ class SGD(Optimizer):
|
|
|
90
90
|
If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
|
|
91
91
|
one group of `params`.
|
|
92
92
|
|
|
93
|
-
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 0.1.
|
|
93
|
+
learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.1`` .
|
|
94
94
|
|
|
95
95
|
- float: The fixed learning rate value. Must be equal to or greater than 0.
|
|
96
96
|
|
|
@@ -104,22 +104,22 @@ class SGD(Optimizer):
|
|
|
104
104
|
- LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
|
|
105
105
|
LearningRateSchedule with step as the input to get the learning rate of current step.
|
|
106
106
|
|
|
107
|
-
momentum (float): A floating point value the momentum. must be at least 0.0. Default: 0.0.
|
|
108
|
-
dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: 0.0.
|
|
109
|
-
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
|
107
|
+
momentum (float): A floating point value the momentum. must be at least 0.0. Default: ``0.0`` .
|
|
108
|
+
dampening (float): A floating point value of dampening for momentum. must be at least 0.0. Default: ``0.0`` .
|
|
109
|
+
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: ``0.0`` .
|
|
110
110
|
nesterov (bool): Enables the Nesterov momentum. If use nesterov, momentum must be positive,
|
|
111
|
-
and dampening must be equal to 0.0. Default: False.
|
|
111
|
+
and dampening must be equal to 0.0. Default: ``False`` .
|
|
112
112
|
loss_scale (float): A floating point value for the loss scale, which must be larger than 0.0. In general, use
|
|
113
113
|
the default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
|
|
114
|
-
`FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
|
|
114
|
+
`FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
|
|
115
115
|
`FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
|
|
116
|
-
Default: 1.0.
|
|
116
|
+
Default: ``1.0`` .
|
|
117
117
|
|
|
118
118
|
Inputs:
|
|
119
119
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
|
120
120
|
|
|
121
121
|
Outputs:
|
|
122
|
-
Tensor[bool], the value is True.
|
|
122
|
+
Tensor[bool], the value is ``True`` .
|
|
123
123
|
|
|
124
124
|
Raises:
|
|
125
125
|
ValueError: If the momentum, dampening or weight_decay value is less than 0.0.
|
|
@@ -131,7 +131,9 @@ class SGD(Optimizer):
|
|
|
131
131
|
>>> import mindspore as ms
|
|
132
132
|
>>> from mindspore import nn
|
|
133
133
|
>>>
|
|
134
|
-
>>>
|
|
134
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
135
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
136
|
+
>>> net = LeNet5()
|
|
135
137
|
>>> #1) All parameters use the same learning rate and weight decay
|
|
136
138
|
>>> optim = nn.SGD(params=net.trainable_params())
|
|
137
139
|
>>>
|
|
@@ -149,7 +151,7 @@ class SGD(Optimizer):
|
|
|
149
151
|
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
|
150
152
|
>>>
|
|
151
153
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
152
|
-
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
|
|
154
|
+
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
|
|
153
155
|
"""
|
|
154
156
|
|
|
155
157
|
@opt_init_args_register
|
|
@@ -161,29 +163,29 @@ class SGD(Optimizer):
|
|
|
161
163
|
if isinstance(momentum, int):
|
|
162
164
|
momentum = float(momentum)
|
|
163
165
|
if not isinstance(momentum, float):
|
|
164
|
-
raise TypeError("For 'SGD', the argument 'momentum' must be float type, "
|
|
165
|
-
"but got {
|
|
166
|
+
raise TypeError(f"For 'SGD', the argument 'momentum' must be float type, "
|
|
167
|
+
f"but got {type(momentum)}.")
|
|
166
168
|
|
|
167
169
|
if isinstance(momentum, float) and momentum < 0.0:
|
|
168
|
-
raise ValueError("For 'SGD', the argument 'momentum' must be at least 0.0, "
|
|
169
|
-
"but got {}."
|
|
170
|
+
raise ValueError(f"For 'SGD', the argument 'momentum' must be at least 0.0, "
|
|
171
|
+
f"but got {momentum}.")
|
|
170
172
|
|
|
171
173
|
if isinstance(dampening, int):
|
|
172
174
|
dampening = float(dampening)
|
|
173
175
|
if not isinstance(dampening, float):
|
|
174
|
-
raise TypeError("For 'SGD', the argument 'dampening' must be float type, "
|
|
175
|
-
"but got {
|
|
176
|
+
raise TypeError(f"For 'SGD', the argument 'dampening' must be float type, "
|
|
177
|
+
f"but got {type(dampening)}.")
|
|
176
178
|
|
|
177
179
|
if dampening < 0.0:
|
|
178
|
-
raise ValueError("For 'SGD', the argument 'dampening' must be at least 0.0, "
|
|
179
|
-
"but got 'dampening' {}"
|
|
180
|
+
raise ValueError(f"For 'SGD', the argument 'dampening' must be at least 0.0, "
|
|
181
|
+
f"but got 'dampening' {dampening}")
|
|
180
182
|
self.dampening = dampening
|
|
181
183
|
|
|
182
184
|
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
|
|
183
185
|
|
|
184
186
|
if nesterov and (momentum <= 0.0 or dampening != 0.0):
|
|
185
|
-
raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
|
|
186
|
-
"equal to 0.0, but got 'momentum' {}, 'dampening' {}"
|
|
187
|
+
raise ValueError(f"For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
|
|
188
|
+
f"equal to 0.0, but got 'momentum' {momentum}, 'dampening' {dampening}.")
|
|
187
189
|
self.nesterov = nesterov
|
|
188
190
|
|
|
189
191
|
if self.dynamic_weight_decay:
|
|
@@ -196,9 +198,23 @@ class SGD(Optimizer):
|
|
|
196
198
|
self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))
|
|
197
199
|
|
|
198
200
|
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
|
201
|
+
|
|
202
|
+
if not momentum > 0.0:
|
|
203
|
+
enable_cache_param_list = []
|
|
204
|
+
for param in self._parameters:
|
|
205
|
+
if param.cache_enable:
|
|
206
|
+
enable_cache_param_list.append(param)
|
|
207
|
+
param.cache_enable = False
|
|
208
|
+
|
|
199
209
|
self.accum = self._parameters.clone(prefix="accum", init='zeros')
|
|
200
210
|
self.stat = self._parameters.clone(prefix="stat", init='ones')
|
|
201
211
|
|
|
212
|
+
|
|
213
|
+
if not momentum > 0.0:
|
|
214
|
+
for param in enable_cache_param_list:
|
|
215
|
+
param.cache_enable = True
|
|
216
|
+
|
|
217
|
+
|
|
202
218
|
@jit
|
|
203
219
|
def construct(self, gradients):
|
|
204
220
|
params = self._parameters
|
|
@@ -208,6 +224,7 @@ class SGD(Optimizer):
|
|
|
208
224
|
gradients = self.gradients_centralization(gradients)
|
|
209
225
|
gradients = self.scale_grad(gradients)
|
|
210
226
|
lr = self.get_lr()
|
|
227
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
211
228
|
if self.is_group_lr:
|
|
212
229
|
success = self.hyper_map_reverse(F.partial(_sgd_opt, self.momentum),
|
|
213
230
|
lr, gradients, params, accum, stat, self.opt)
|
mindspore/nn/optim/thor.py
CHANGED
|
@@ -266,10 +266,10 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
266
266
|
\otimes\left(G_{i}^{(k)}+\lambda I\right)^{-1}\right) \nabla_{w_{i}} J^{(k)}
|
|
267
267
|
\end{array}
|
|
268
268
|
|
|
269
|
-
:math:`a_{i-1}` represents the input of i
|
|
270
|
-
:math:`D_{s_i}` represents the derivative of the loss function of the output of the i
|
|
269
|
+
:math:`a_{i-1}` represents the input of :math:`i`-th layer,and which is the activations of previous layer.
|
|
270
|
+
:math:`D_{s_i}` represents the derivative of the loss function of the output of the :math:`i`-th layer.
|
|
271
271
|
:math:`I` represents the identity matrix.
|
|
272
|
-
:math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the i
|
|
272
|
+
:math:`\lambda` represents :math:`damping`, :math:`g_i` represents gradients of the :math:`i`-th layer.
|
|
273
273
|
:math:`\otimes` represents Kronecker product, :math:`\gamma` represents 'learning rate'.
|
|
274
274
|
|
|
275
275
|
Note:
|
|
@@ -290,14 +290,15 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
290
290
|
|
|
291
291
|
momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.
|
|
292
292
|
|
|
293
|
-
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
|
|
293
|
+
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0.
|
|
294
|
+
Default: ``0.0`` .
|
|
294
295
|
|
|
295
296
|
loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
|
|
296
|
-
default value. Default: 1.0.
|
|
297
|
+
default value. Default: ``1.0`` .
|
|
297
298
|
|
|
298
|
-
batch_size (int): The size of a batch. Default: 32
|
|
299
|
+
batch_size (int): The size of a batch. Default: ``32`` .
|
|
299
300
|
|
|
300
|
-
use_nesterov (bool): Enable Nesterov momentum. Default: False.
|
|
301
|
+
use_nesterov (bool): Enable Nesterov momentum. Default: ``False`` .
|
|
301
302
|
|
|
302
303
|
decay_filter (function): A function to determine which layers the weight decay applied to. And it
|
|
303
304
|
only works when the weight_decay > 0. Default: lambda x: x.name not in []
|
|
@@ -305,13 +306,13 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
305
306
|
split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
|
|
306
307
|
computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
|
|
307
308
|
to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
|
|
308
|
-
is 27~53. Default: None
|
|
309
|
+
is 27~53. Default: ``None`` .
|
|
309
310
|
|
|
310
|
-
enable_clip_grad (bool): Whether to clip the gradients. Default: False
|
|
311
|
+
enable_clip_grad (bool): Whether to clip the gradients. Default: ``False`` .
|
|
311
312
|
|
|
312
313
|
frequency(int): The update interval of A/G and :math:`A^{-1}/G^{-1}`. When frequency equals N
|
|
313
314
|
(N is greater than 1), A/G and :math:`A^{-1}/G^{-1}` will be updated every N steps,
|
|
314
|
-
and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: 100.
|
|
315
|
+
and other steps will use the stale A/G and :math:`A^{-1}/G^{-1}` to update weights. Default: ``100`` .
|
|
315
316
|
|
|
316
317
|
Inputs:
|
|
317
318
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
|
@@ -333,21 +334,18 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
333
334
|
``Ascend`` ``GPU``
|
|
334
335
|
|
|
335
336
|
Examples:
|
|
336
|
-
.. note::
|
|
337
|
-
Before running the following example, you need to customize the network Net and
|
|
338
|
-
dataset preparation function create_dataset. Refer to
|
|
339
|
-
`Building a Network <https://www.mindspore.cn/tutorials/en/r2.0/beginner/model.html>`_
|
|
340
|
-
and `Dataset <https://www.mindspore.cn/tutorials/en/r2.0/beginner/dataset.html>`_ .
|
|
341
|
-
|
|
342
337
|
>>> import mindspore as ms
|
|
343
|
-
>>> from mindspore.nn import thor
|
|
344
338
|
>>> from mindspore import nn
|
|
345
339
|
>>> from mindspore import Tensor
|
|
346
340
|
>>>
|
|
347
|
-
>>>
|
|
341
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
342
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
343
|
+
>>> net = LeNet5()
|
|
344
|
+
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
345
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
|
|
348
346
|
>>> dataset = create_dataset()
|
|
349
347
|
>>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
|
|
350
|
-
>>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
|
|
348
|
+
>>> optim = nn.thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
|
|
351
349
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
352
350
|
>>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
|
|
353
351
|
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
|
|
@@ -355,8 +353,6 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|
|
355
353
|
>>> model = ms.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
|
|
356
354
|
... loss_scale_manager=loss_scale, metrics={'acc'},
|
|
357
355
|
... amp_level="O2", keep_batchnorm_fp32=False)
|
|
358
|
-
>>> loss_cb = ms.LossMonitor()
|
|
359
|
-
>>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)
|
|
360
356
|
|
|
361
357
|
"""
|
|
362
358
|
context.set_context(max_call_depth=10000)
|
|
@@ -428,7 +424,7 @@ class ThorGpu(Optimizer):
|
|
|
428
424
|
self.matmul = P.MatMul()
|
|
429
425
|
self.assign = P.Assign()
|
|
430
426
|
self.mul = P.Mul()
|
|
431
|
-
self.gather = P.
|
|
427
|
+
self.gather = P.Gather()
|
|
432
428
|
self.one = Tensor(1, mstype.int32)
|
|
433
429
|
self.feature_map = Tensor(1.0, mstype.float32)
|
|
434
430
|
self.axis = 0
|
|
@@ -657,6 +653,7 @@ class ThorGpu(Optimizer):
|
|
|
657
653
|
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
|
|
658
654
|
gradients = clip_gradient(self.enable_clip_grad, gradients)
|
|
659
655
|
lr = self.get_lr()
|
|
656
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
660
657
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
|
661
658
|
return success
|
|
662
659
|
|
|
@@ -681,7 +678,7 @@ class ThorAscend(Optimizer):
|
|
|
681
678
|
self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
|
|
682
679
|
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
|
|
683
680
|
self._define_ascend_operator()
|
|
684
|
-
self.
|
|
681
|
+
self.c0 = 16
|
|
685
682
|
self.device_shape_pad_flag = ()
|
|
686
683
|
self.diag_block_dim = 128
|
|
687
684
|
self.matrix_a = ()
|
|
@@ -743,7 +740,7 @@ class ThorAscend(Optimizer):
|
|
|
743
740
|
self.log = P.Log()
|
|
744
741
|
self.exp = P.Exp()
|
|
745
742
|
self.sqrt = P.Sqrt()
|
|
746
|
-
self.gather = P.
|
|
743
|
+
self.gather = P.Gather()
|
|
747
744
|
self.assign = P.Assign()
|
|
748
745
|
self.cast = P.Cast()
|
|
749
746
|
self.eye = P.Eye()
|
|
@@ -989,8 +986,8 @@ class ThorAscend(Optimizer):
|
|
|
989
986
|
kernel_hw = weight_shape[2] * weight_shape[3]
|
|
990
987
|
in_channels = weight_shape[1]
|
|
991
988
|
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
|
|
992
|
-
matrix_a_inv = P.Pad(((0, 0), (0, self.
|
|
993
|
-
(0, self.
|
|
989
|
+
matrix_a_inv = P.Pad(((0, 0), (0, self.c0 - in_channels), (0, 0),
|
|
990
|
+
(0, self.c0 - in_channels)))(matrix_a_inv)
|
|
994
991
|
return matrix_a_inv
|
|
995
992
|
|
|
996
993
|
def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
|
|
@@ -1308,5 +1305,6 @@ class ThorAscend(Optimizer):
|
|
|
1308
1305
|
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
|
|
1309
1306
|
gradients = clip_gradient(self.enable_clip_grad, gradients)
|
|
1310
1307
|
lr = self.get_lr()
|
|
1308
|
+
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
|
1311
1309
|
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
|
1312
1310
|
return success
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from mindspore import context
|
|
17
17
|
from mindspore.nn.cell import Cell
|
|
18
18
|
from mindspore.ops import operations as P
|
|
19
|
+
from mindspore.ops import functional as F
|
|
19
20
|
from mindspore.ops.operations import _inner_ops as inner
|
|
20
21
|
from mindspore.common import dtype as mstype
|
|
21
22
|
from mindspore.common.tensor import Tensor
|
|
@@ -33,11 +34,11 @@ class Bijector(Cell):
|
|
|
33
34
|
then :math:`Y = g(X)` is the random variable following the transformed distribution.
|
|
34
35
|
|
|
35
36
|
Args:
|
|
36
|
-
is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: False.
|
|
37
|
-
is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: True.
|
|
38
|
-
name (str): The name of the Bijector. Default: None.
|
|
39
|
-
dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: None.
|
|
40
|
-
param (dict): The parameters used to initialize the Bijector. Default: None.
|
|
37
|
+
is_constant_jacobian (bool): Whether the Bijector has constant derivative. Default: ``False`` .
|
|
38
|
+
is_injective (bool): Whether the Bijector is a one-to-one mapping. Default: ``True`` .
|
|
39
|
+
name (str): The name of the Bijector. Default: ``None`` .
|
|
40
|
+
dtype (mindspore.dtype): The type of the distributions that the Bijector can operate on. Default: ``None`` .
|
|
41
|
+
param (dict): The parameters used to initialize the Bijector. Default: ``None`` .
|
|
41
42
|
|
|
42
43
|
Note:
|
|
43
44
|
`dtype` of bijector represents the type of the distributions that the bijector could operate on.
|
|
@@ -96,7 +97,6 @@ class Bijector(Cell):
|
|
|
96
97
|
self.cast_base = P.Cast()
|
|
97
98
|
self.dtype_base = P.DType()
|
|
98
99
|
self.shape_base = P.Shape()
|
|
99
|
-
self.fill_base = P.Fill()
|
|
100
100
|
self.sametypeshape_base = inner.SameTypeShape()
|
|
101
101
|
self.issubclass_base = inner.IsSubClass()
|
|
102
102
|
|
|
@@ -140,13 +140,13 @@ class Bijector(Cell):
|
|
|
140
140
|
if self.issubclass_base(value_type, mstype.float_):
|
|
141
141
|
return value
|
|
142
142
|
return raise_type_error('input value of bijector', value_type, mstype.float_)
|
|
143
|
-
dtype_tensor =
|
|
143
|
+
dtype_tensor = F.fill(self.dtype, self.shape_base(value), 0.0)
|
|
144
144
|
self.sametypeshape_base(value, dtype_tensor)
|
|
145
145
|
return value
|
|
146
146
|
|
|
147
147
|
def _shape_mapping(self, shape):
|
|
148
|
-
shape_tensor =
|
|
149
|
-
dist_shape_tensor =
|
|
148
|
+
shape_tensor = F.fill(self.parameter_type, shape, 0.0)
|
|
149
|
+
dist_shape_tensor = F.fill(
|
|
150
150
|
self.parameter_type, self.batch_shape, 0.0)
|
|
151
151
|
return (shape_tensor + dist_shape_tensor).shape
|
|
152
152
|
|
|
@@ -165,7 +165,7 @@ class Bijector(Cell):
|
|
|
165
165
|
self.common_dtype = None
|
|
166
166
|
# cast value to a tensor if it is not None
|
|
167
167
|
if isinstance(value, bool) or value is None:
|
|
168
|
-
raise TypeError("{} cannot be type {
|
|
168
|
+
raise TypeError(f"{name} cannot be type {type(value)}")
|
|
169
169
|
value_t = Tensor(value)
|
|
170
170
|
# if the bijector's dtype is not specified
|
|
171
171
|
if self.dtype is None:
|
|
@@ -189,7 +189,7 @@ class Bijector(Cell):
|
|
|
189
189
|
"""
|
|
190
190
|
Calculate batch_shape based on parameters.
|
|
191
191
|
"""
|
|
192
|
-
if 'param_dict' not in self.parameters
|
|
192
|
+
if 'param_dict' not in self.parameters:
|
|
193
193
|
return None
|
|
194
194
|
param_dict = self.parameters.get('param_dict')
|
|
195
195
|
broadcast_shape_tensor = None
|
|
@@ -28,9 +28,9 @@ class GumbelCDF(Bijector):
|
|
|
28
28
|
Y = \exp(-\exp(\frac{-(X - loc)}{scale}))
|
|
29
29
|
|
|
30
30
|
Args:
|
|
31
|
-
loc (float, list, numpy.ndarray, Tensor): The location. Default: 0.0.
|
|
32
|
-
scale (float, list, numpy.ndarray, Tensor): The scale. Default: 1.0.
|
|
33
|
-
name (str): The name of the Bijector. Default: 'GumbelCDF'.
|
|
31
|
+
loc (float, list, numpy.ndarray, Tensor): The location. Default: ``0.0`` .
|
|
32
|
+
scale (float, list, numpy.ndarray, Tensor): The scale. Default: ``1.0`` .
|
|
33
|
+
name (str): The name of the Bijector. Default: ``'GumbelCDF'`` .
|
|
34
34
|
|
|
35
35
|
Note:
|
|
36
36
|
`scale` must be greater than zero.
|
|
@@ -25,7 +25,7 @@ class Invert(Bijector):
|
|
|
25
25
|
|
|
26
26
|
Args:
|
|
27
27
|
bijector (Bijector): Base Bijector.
|
|
28
|
-
name (str): The name of the Bijector. Default: "". When name is set to "", it is actually
|
|
28
|
+
name (str): The name of the Bijector. Default: ``""`` . When name is set to "", it is actually
|
|
29
29
|
'Invert' + bijector.name.
|
|
30
30
|
|
|
31
31
|
Supported Platforms:
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""PowerTransform Bijector"""
|
|
16
16
|
from mindspore.ops import operations as P
|
|
17
|
+
from mindspore.ops import functional as F
|
|
17
18
|
from ..distribution._utils.utils import check_greater_equal_zero
|
|
18
19
|
from ..distribution._utils.custom_ops import exp_generic, log_generic
|
|
19
20
|
from .bijector import Bijector
|
|
@@ -34,8 +35,8 @@ class PowerTransform(Bijector):
|
|
|
34
35
|
This Bijector is equivalent to the :class:`mindspore.nn.probability.bijector.Exp` bijector when `c=0`.
|
|
35
36
|
|
|
36
37
|
Args:
|
|
37
|
-
power (float, list, numpy.ndarray, Tensor): The scale factor. Default: 0.
|
|
38
|
-
name (str): The name of the bijector. Default: 'PowerTransform'.
|
|
38
|
+
power (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``0`` .
|
|
39
|
+
name (str): The name of the bijector. Default: ``'PowerTransform'`` .
|
|
39
40
|
|
|
40
41
|
Note:
|
|
41
42
|
The dtype of `power` must be float.
|
|
@@ -68,10 +69,7 @@ class PowerTransform(Bijector):
|
|
|
68
69
|
>>> print(ans4.shape)
|
|
69
70
|
(3,)
|
|
70
71
|
"""
|
|
71
|
-
|
|
72
|
-
def __init__(self,
|
|
73
|
-
power=0.,
|
|
74
|
-
name='PowerTransform'):
|
|
72
|
+
def __init__(self, power=0., name='PowerTransform'):
|
|
75
73
|
param = dict(locals())
|
|
76
74
|
param['param_dict'] = {'power': power}
|
|
77
75
|
super(PowerTransform, self).__init__(name=name, param=param)
|
|
@@ -84,7 +82,6 @@ class PowerTransform(Bijector):
|
|
|
84
82
|
self.equal_base = P.Equal()
|
|
85
83
|
self.exp = exp_generic
|
|
86
84
|
self.expm1 = P.Expm1()
|
|
87
|
-
self.fill = P.Fill()
|
|
88
85
|
self.log = log_generic
|
|
89
86
|
self.log1p = P.Log1p()
|
|
90
87
|
self.select_base = P.Select()
|
|
@@ -116,17 +113,18 @@ class PowerTransform(Bijector):
|
|
|
116
113
|
power_local = self.cast_param_by_value(x, self.power)
|
|
117
114
|
|
|
118
115
|
# broad cast the value of x and power
|
|
119
|
-
ones =
|
|
120
|
-
|
|
116
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
|
|
117
|
+
1.)
|
|
121
118
|
power_local = power_local * ones
|
|
122
119
|
x = x * ones
|
|
123
|
-
safe_power = self.select_base(
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
forward_v = self.select_base(
|
|
128
|
-
|
|
129
|
-
|
|
120
|
+
safe_power = self.select_base(
|
|
121
|
+
self.equal_base(power_local,
|
|
122
|
+
P.ZerosLike()(power_local)), ones, power_local)
|
|
123
|
+
|
|
124
|
+
forward_v = self.select_base(
|
|
125
|
+
self.equal_base(power_local,
|
|
126
|
+
P.ZerosLike()(power_local)), self.exp(x),
|
|
127
|
+
self.exp(self.log1p(x * safe_power) / safe_power))
|
|
130
128
|
return forward_v
|
|
131
129
|
|
|
132
130
|
def _inverse(self, y):
|
|
@@ -137,17 +135,18 @@ class PowerTransform(Bijector):
|
|
|
137
135
|
power_local = self.cast_param_by_value(y, self.power)
|
|
138
136
|
|
|
139
137
|
# broad cast the value of x and power
|
|
140
|
-
ones =
|
|
141
|
-
|
|
138
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(y + power_local),
|
|
139
|
+
1.)
|
|
142
140
|
power_local = power_local * ones
|
|
143
141
|
y = y * ones
|
|
144
|
-
safe_power = self.select_base(
|
|
145
|
-
|
|
146
|
-
|
|
142
|
+
safe_power = self.select_base(
|
|
143
|
+
self.equal_base(power_local,
|
|
144
|
+
P.ZerosLike()(power_local)), ones, power_local)
|
|
147
145
|
|
|
148
|
-
inverse_v = self.select_base(
|
|
149
|
-
|
|
150
|
-
|
|
146
|
+
inverse_v = self.select_base(
|
|
147
|
+
self.equal_base(power_local,
|
|
148
|
+
P.ZerosLike()(power_local)), self.log(y),
|
|
149
|
+
self.expm1(self.log(y) * safe_power) / safe_power)
|
|
151
150
|
|
|
152
151
|
return inverse_v
|
|
153
152
|
|
|
@@ -167,14 +166,15 @@ class PowerTransform(Bijector):
|
|
|
167
166
|
power_local = self.cast_param_by_value(x, self.power)
|
|
168
167
|
|
|
169
168
|
# broad cast the value of x and power
|
|
170
|
-
ones =
|
|
171
|
-
|
|
169
|
+
ones = F.fill(self.dtypeop(power_local), self.shape(x + power_local),
|
|
170
|
+
1.)
|
|
172
171
|
power_local = power_local * ones
|
|
173
172
|
x = x * ones
|
|
174
173
|
|
|
175
|
-
forward_log_j = self.select_base(
|
|
176
|
-
|
|
177
|
-
|
|
174
|
+
forward_log_j = self.select_base(
|
|
175
|
+
self.equal_base(power_local,
|
|
176
|
+
P.ZerosLike()(power_local)), x,
|
|
177
|
+
(1. / power_local - 1) * self.log1p(x * power_local))
|
|
178
178
|
|
|
179
179
|
return forward_log_j
|
|
180
180
|
|
|
@@ -29,9 +29,9 @@ class ScalarAffine(Bijector):
|
|
|
29
29
|
where a is the scale factor and b is the shift factor.
|
|
30
30
|
|
|
31
31
|
Args:
|
|
32
|
-
scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: 1.0.
|
|
33
|
-
shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: 0.0.
|
|
34
|
-
name (str): The name of the bijector. Default: 'ScalarAffine'.
|
|
32
|
+
scale (float, list, numpy.ndarray, Tensor): The scale factor. Default: ``1.0`` .
|
|
33
|
+
shift (float, list, numpy.ndarray, Tensor): The shift factor. Default: ``0.0`` .
|
|
34
|
+
name (str): The name of the bijector. Default: ``'ScalarAffine'`` .
|
|
35
35
|
|
|
36
36
|
Note:
|
|
37
37
|
The dtype of `shift` and `scale` must be float.
|