mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Refer to the registry's advisory notice for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -23,7 +23,8 @@ from PIL import Image
|
|
|
23
23
|
import mindspore
|
|
24
24
|
import mindspore._c_dataengine as cde
|
|
25
25
|
|
|
26
|
-
|
|
26
|
+
# The following constants have been deprecated by Pillow since version 9.1.0
|
|
27
|
+
if int(Image.__version__.split(".")[0]) > 9 or Image.__version__ >= "9.1.0":
|
|
27
28
|
FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT
|
|
28
29
|
FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM
|
|
29
30
|
PERSPECTIVE = Image.Transform.PERSPECTIVE
|
|
@@ -47,14 +48,14 @@ class AutoAugmentPolicy(str, Enum):
|
|
|
47
48
|
"""
|
|
48
49
|
AutoAugment policy for different datasets.
|
|
49
50
|
|
|
50
|
-
Possible enumeration values are: AutoAugmentPolicy.IMAGENET
|
|
51
|
+
Possible enumeration values are: ``AutoAugmentPolicy.IMAGENET``, ``AutoAugmentPolicy.CIFAR10``,
|
|
51
52
|
AutoAugmentPolicy.SVHN.
|
|
52
53
|
|
|
53
54
|
Each policy contains 25 pairs of augmentation operations. When using AutoAugment, each image is randomly
|
|
54
55
|
transformed with one of these operation pairs. Each pair has 2 different operations. The following shows
|
|
55
56
|
all of these augmentation operations, including operation names with their probabilities and random params.
|
|
56
57
|
|
|
57
|
-
- AutoAugmentPolicy.IMAGENET
|
|
58
|
+
- ``AutoAugmentPolicy.IMAGENET``: dataset auto augment policy for ImageNet.
|
|
58
59
|
|
|
59
60
|
.. code-block::
|
|
60
61
|
|
|
@@ -73,7 +74,7 @@ class AutoAugmentPolicy(str, Enum):
|
|
|
73
74
|
(("Invert", 0.6, None), ("Equalize", 1.0, None)), (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
|
|
74
75
|
(("Equalize", 0.8, None), ("Equalize", 0.6, None))]
|
|
75
76
|
|
|
76
|
-
- AutoAugmentPolicy.CIFAR10
|
|
77
|
+
- ``AutoAugmentPolicy.CIFAR10``: dataset auto augment policy for Cifar10.
|
|
77
78
|
|
|
78
79
|
.. code-block::
|
|
79
80
|
|
|
@@ -94,7 +95,7 @@ class AutoAugmentPolicy(str, Enum):
|
|
|
94
95
|
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
|
|
95
96
|
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None))]
|
|
96
97
|
|
|
97
|
-
- AutoAugmentPolicy.SVHN
|
|
98
|
+
- ``AutoAugmentPolicy.SVHN``: dataset auto augment policy for SVHN.
|
|
98
99
|
|
|
99
100
|
.. code-block::
|
|
100
101
|
|
|
@@ -133,13 +134,13 @@ class Border(str, Enum):
|
|
|
133
134
|
"""
|
|
134
135
|
Padding Mode, Border Type.
|
|
135
136
|
|
|
136
|
-
Possible enumeration values are: Border.CONSTANT
|
|
137
|
+
Possible enumeration values are: ``Border.CONSTANT``, ``Border.EDGE``, ``Border.REFLECT``, ``Border.SYMMETRIC``.
|
|
137
138
|
|
|
138
|
-
- Border.CONSTANT: means it fills the border with constant values.
|
|
139
|
-
- Border.EDGE: means it pads with the last value on the edge.
|
|
140
|
-
- Border.REFLECT: means it reflects the values on the edge omitting the last value of edge.
|
|
139
|
+
- ``Border.CONSTANT`` : means it fills the border with constant values.
|
|
140
|
+
- ``Border.EDGE`` : means it pads with the last value on the edge.
|
|
141
|
+
- ``Border.REFLECT`` : means it reflects the values on the edge omitting the last value of edge.
|
|
141
142
|
For example, padding [1,2,3,4] with 2 elements on both sides will result in [3,2,1,2,3,4,3,2].
|
|
142
|
-
- Border.SYMMETRIC: means it reflects the values on the edge repeating the last value of edge.
|
|
143
|
+
- ``Border.SYMMETRIC`` : means it reflects the values on the edge repeating the last value of edge.
|
|
143
144
|
For example, padding [1,2,3,4] with 2 elements on both sides will result in [2,1,1,2,3,4,4,3].
|
|
144
145
|
|
|
145
146
|
Note:
|
|
@@ -256,10 +257,10 @@ class ImageBatchFormat(IntEnum):
|
|
|
256
257
|
"""
|
|
257
258
|
Data Format of images after batch operation.
|
|
258
259
|
|
|
259
|
-
Possible enumeration values are: ImageBatchFormat.NHWC
|
|
260
|
+
Possible enumeration values are: ``ImageBatchFormat.NHWC``, ``ImageBatchFormat.NCHW``.
|
|
260
261
|
|
|
261
|
-
- ImageBatchFormat.NHWC
|
|
262
|
-
- ImageBatchFormat.NCHW
|
|
262
|
+
- ``ImageBatchFormat.NHWC``: in orders like, batch N, height H, width W, channels C to store the data.
|
|
263
|
+
- ``ImageBatchFormat.NCHW``: in orders like, batch N, channels C, height H, width W to store the data.
|
|
263
264
|
"""
|
|
264
265
|
NHWC = 0
|
|
265
266
|
NCHW = 1
|
|
@@ -279,11 +280,11 @@ class ImageReadMode(IntEnum):
|
|
|
279
280
|
"""
|
|
280
281
|
The read mode used for the image file.
|
|
281
282
|
|
|
282
|
-
Possible enumeration values are: ImageReadMode.UNCHANGED
|
|
283
|
+
Possible enumeration values are: ``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAYSCALE``, ``ImageReadMode.COLOR``.
|
|
283
284
|
|
|
284
|
-
- ImageReadMode.UNCHANGED
|
|
285
|
-
- ImageReadMode.GRAYSCALE
|
|
286
|
-
- ImageReadMode.COLOR
|
|
285
|
+
- ``ImageReadMode.UNCHANGED``: remain the output in the original format.
|
|
286
|
+
- ``ImageReadMode.GRAYSCALE``: convert the output into one channel grayscale data.
|
|
287
|
+
- ``ImageReadMode.COLOR``: convert the output into three channels RGB color data.
|
|
287
288
|
"""
|
|
288
289
|
UNCHANGED = 0
|
|
289
290
|
GRAYSCALE = 1
|
|
@@ -302,20 +303,19 @@ class ImageReadMode(IntEnum):
|
|
|
302
303
|
|
|
303
304
|
class Inter(IntEnum):
|
|
304
305
|
"""
|
|
305
|
-
Interpolation
|
|
306
|
+
Interpolation methods.
|
|
306
307
|
|
|
307
|
-
|
|
308
|
-
Inter.BICUBIC, Inter.AREA, Inter.PILCUBIC.
|
|
308
|
+
Available values are as follows:
|
|
309
309
|
|
|
310
|
-
- Inter.NEAREST:
|
|
311
|
-
- Inter.ANTIALIAS:
|
|
312
|
-
- Inter.LINEAR:
|
|
313
|
-
- Inter.BILINEAR:
|
|
314
|
-
- Inter.CUBIC:
|
|
315
|
-
- Inter.BICUBIC:
|
|
316
|
-
- Inter.AREA:
|
|
317
|
-
- Inter.PILCUBIC:
|
|
318
|
-
|
|
310
|
+
- ``Inter.NEAREST`` : Nearest neighbor interpolation.
|
|
311
|
+
- ``Inter.ANTIALIAS`` : Antialias interpolation. Supported only when the input is PIL.Image.Image.
|
|
312
|
+
- ``Inter.LINEAR`` : Linear interpolation, the same as ``Inter.BILINEAR``.
|
|
313
|
+
- ``Inter.BILINEAR`` : Bilinear interpolation.
|
|
314
|
+
- ``Inter.CUBIC`` : Cubic interpolation, the same as ``Inter.BICUBIC``.
|
|
315
|
+
- ``Inter.BICUBIC`` : Bicubic interpolation.
|
|
316
|
+
- ``Inter.AREA`` : Pixel area interpolation. Supported only when the input is numpy.ndarray.
|
|
317
|
+
- ``Inter.PILCUBIC`` : Pillow implementation of bicubic interpolation. Supported only when the input
|
|
318
|
+
is numpy.ndarray.
|
|
319
319
|
"""
|
|
320
320
|
NEAREST = 0
|
|
321
321
|
ANTIALIAS = 1
|
|
@@ -354,10 +354,10 @@ class SliceMode(IntEnum):
|
|
|
354
354
|
"""
|
|
355
355
|
Mode to Slice Tensor into multiple parts.
|
|
356
356
|
|
|
357
|
-
Possible enumeration values are: SliceMode.PAD
|
|
357
|
+
Possible enumeration values are: ``SliceMode.PAD``, ``SliceMode.DROP``.
|
|
358
358
|
|
|
359
|
-
- SliceMode.PAD
|
|
360
|
-
- SliceMode.DROP
|
|
359
|
+
- ``SliceMode.PAD``: pad some pixels before slice the Tensor if needed.
|
|
360
|
+
- ``SliceMode.DROP``: drop remainder pixels before slice the Tensor if needed.
|
|
361
361
|
"""
|
|
362
362
|
PAD = 0
|
|
363
363
|
DROP = 1
|
|
@@ -379,7 +379,7 @@ def encode_jpeg(image, quality=75):
|
|
|
379
379
|
|
|
380
380
|
Args:
|
|
381
381
|
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
|
|
382
|
-
quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: 75
|
|
382
|
+
quality (int, optional): Quality of the resulting JPEG data, in range of [1, 100]. Default: ``75``.
|
|
383
383
|
|
|
384
384
|
Returns:
|
|
385
385
|
numpy.ndarray, one dimension uint8 data.
|
|
@@ -395,6 +395,7 @@ def encode_jpeg(image, quality=75):
|
|
|
395
395
|
``CPU``
|
|
396
396
|
|
|
397
397
|
Examples:
|
|
398
|
+
>>> import mindspore.dataset.vision as vision
|
|
398
399
|
>>> import numpy as np
|
|
399
400
|
>>> # Generate a random image with height=120, width=340, channels=3
|
|
400
401
|
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
|
|
@@ -416,7 +417,8 @@ def encode_png(image, compression_level=6):
|
|
|
416
417
|
|
|
417
418
|
Args:
|
|
418
419
|
image (Union[numpy.ndarray, mindspore.Tensor]): The image to be encoded.
|
|
419
|
-
compression_level (int, optional): The compression_level for encoding, in range of [0, 9].
|
|
420
|
+
compression_level (int, optional): The `compression_level` for encoding, in range of [0, 9].
|
|
421
|
+
Default: ``6``.
|
|
420
422
|
|
|
421
423
|
Returns:
|
|
422
424
|
numpy.ndarray, one dimension uint8 data.
|
|
@@ -432,6 +434,7 @@ def encode_png(image, compression_level=6):
|
|
|
432
434
|
``CPU``
|
|
433
435
|
|
|
434
436
|
Examples:
|
|
437
|
+
>>> import mindspore.dataset.vision as vision
|
|
435
438
|
>>> import numpy as np
|
|
436
439
|
>>> # Generate a random image with height=120, width=340, channels=3
|
|
437
440
|
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
|
|
@@ -463,6 +466,9 @@ def get_image_num_channels(image):
|
|
|
463
466
|
TypeError: If `image` is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>.
|
|
464
467
|
|
|
465
468
|
Examples:
|
|
469
|
+
>>> import mindspore.dataset.vision as vision
|
|
470
|
+
>>> from PIL import Image
|
|
471
|
+
>>> image = Image.open("/path/to/image_file")
|
|
466
472
|
>>> num_channels = vision.get_image_num_channels(image)
|
|
467
473
|
"""
|
|
468
474
|
|
|
@@ -493,6 +499,9 @@ def get_image_size(image):
|
|
|
493
499
|
TypeError: If `image` is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>.
|
|
494
500
|
|
|
495
501
|
Examples:
|
|
502
|
+
>>> import mindspore.dataset.vision as vision
|
|
503
|
+
>>> from PIL import Image
|
|
504
|
+
>>> image = Image.open("/path/to/image_file")
|
|
496
505
|
>>> image_size = vision.get_image_size(image)
|
|
497
506
|
"""
|
|
498
507
|
|
|
@@ -538,6 +547,7 @@ def read_file(filename):
|
|
|
538
547
|
``CPU``
|
|
539
548
|
|
|
540
549
|
Examples:
|
|
550
|
+
>>> import mindspore.dataset.vision as vision
|
|
541
551
|
>>> output = vision.read_file("/path/to/file")
|
|
542
552
|
"""
|
|
543
553
|
if isinstance(filename, str):
|
|
@@ -552,8 +562,9 @@ def read_image(filename, mode=ImageReadMode.UNCHANGED):
|
|
|
552
562
|
|
|
553
563
|
Args:
|
|
554
564
|
filename(str): The path to the image file to be read.
|
|
555
|
-
mode(ImageReadMode, optional): The mode used for decoding the image. It can be
|
|
556
|
-
|
|
565
|
+
mode(ImageReadMode, optional): The mode used for decoding the image. It can be
|
|
566
|
+
``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAYSCALE``, ``IMageReadMode.COLOR``.
|
|
567
|
+
Default: ``ImageReadMode.UNCHANGED``.
|
|
557
568
|
|
|
558
569
|
- ImageReadMode.UNCHANGED, remain the output in the original format.
|
|
559
570
|
|
|
@@ -573,6 +584,7 @@ def read_image(filename, mode=ImageReadMode.UNCHANGED):
|
|
|
573
584
|
``CPU``
|
|
574
585
|
|
|
575
586
|
Examples:
|
|
587
|
+
>>> import mindspore.dataset.vision as vision
|
|
576
588
|
>>> from mindspore.dataset.vision import ImageReadMode
|
|
577
589
|
>>> output = vision.read_image("/path/to/image_file", ImageReadMode.UNCHANGED)
|
|
578
590
|
"""
|
|
@@ -602,6 +614,7 @@ def write_file(filename, data):
|
|
|
602
614
|
``CPU``
|
|
603
615
|
|
|
604
616
|
Examples:
|
|
617
|
+
>>> import mindspore.dataset.vision as vision
|
|
605
618
|
>>> import numpy as np
|
|
606
619
|
>>> # Generate a random data with 1024 bytes
|
|
607
620
|
>>> data = np.random.randint(256, size=(1024), dtype=np.uint8)
|
|
@@ -624,7 +637,7 @@ def write_jpeg(filename, image, quality=75):
|
|
|
624
637
|
Args:
|
|
625
638
|
filename (str): The path to the file to be written.
|
|
626
639
|
image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
|
|
627
|
-
quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: 75
|
|
640
|
+
quality (int, optional): Quality of the resulting JPEG file, in range of [1, 100]. Default: ``75``.
|
|
628
641
|
|
|
629
642
|
Raises:
|
|
630
643
|
TypeError: If `filename` is not of type str.
|
|
@@ -639,6 +652,7 @@ def write_jpeg(filename, image, quality=75):
|
|
|
639
652
|
``CPU``
|
|
640
653
|
|
|
641
654
|
Examples:
|
|
655
|
+
>>> import mindspore.dataset.vision as vision
|
|
642
656
|
>>> import numpy as np
|
|
643
657
|
>>> # Generate a random image with height=120, width=340, channels=3
|
|
644
658
|
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
|
|
@@ -663,7 +677,8 @@ def write_png(filename, image, compression_level=6):
|
|
|
663
677
|
Args:
|
|
664
678
|
filename (str): The path to the file to be written.
|
|
665
679
|
image (Union[numpy.ndarray, mindspore.Tensor]): The image data to be written.
|
|
666
|
-
compression_level (int, optional): Compression level for the resulting PNG file, in range of [0, 9].
|
|
680
|
+
compression_level (int, optional): Compression level for the resulting PNG file, in range of [0, 9].
|
|
681
|
+
Default: ``6``.
|
|
667
682
|
|
|
668
683
|
Raises:
|
|
669
684
|
TypeError: If `filename` is not of type str.
|
|
@@ -678,6 +693,7 @@ def write_png(filename, image, compression_level=6):
|
|
|
678
693
|
``CPU``
|
|
679
694
|
|
|
680
695
|
Examples:
|
|
696
|
+
>>> import mindspore.dataset.vision as vision
|
|
681
697
|
>>> import numpy as np
|
|
682
698
|
>>> # Generate a random image with height=120, width=340, channels=3
|
|
683
699
|
>>> image = np.random.randint(256, size=(120, 340, 3), dtype=np.uint8)
|
|
@@ -24,7 +24,7 @@ from mindspore.dataset.core.validator_helpers import check_value, check_uint8, F
|
|
|
24
24
|
check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
|
|
25
25
|
parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, UINT8_MIN, check_value_normalize_std, \
|
|
26
26
|
check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
|
|
27
|
-
check_pos_int32, check_int32, check_tensor_op, deprecator_factory
|
|
27
|
+
check_pos_int32, check_int32, check_tensor_op, deprecator_factory, check_valid_str
|
|
28
28
|
from mindspore.dataset.transforms.validators import check_transform_op_type
|
|
29
29
|
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
|
|
30
30
|
|
|
@@ -339,6 +339,16 @@ def check_resize_interpolation(method):
|
|
|
339
339
|
|
|
340
340
|
return new_method
|
|
341
341
|
|
|
342
|
+
def check_device_target(method):
|
|
343
|
+
"""A wrapper that wraps a parameter checker"""
|
|
344
|
+
|
|
345
|
+
@wraps(method)
|
|
346
|
+
def new_method(self, *args, **kwargs):
|
|
347
|
+
[device_target], _ = parse_user_args(method, *args, **kwargs)
|
|
348
|
+
check_valid_str(device_target, ["CPU", "Ascend"], "device_target")
|
|
349
|
+
return method(self, *args, **kwargs)
|
|
350
|
+
return new_method
|
|
351
|
+
|
|
342
352
|
|
|
343
353
|
def check_resized_crop(method):
|
|
344
354
|
"""A wrapper that wraps a parameter checker around the original function(ResizedCrop operation)."""
|
|
@@ -715,7 +725,7 @@ def check_pad_to_size(method):
|
|
|
715
725
|
else:
|
|
716
726
|
if len(offset) not in [0, 2]:
|
|
717
727
|
raise ValueError("The offset must be empty or a sequence of length 2.")
|
|
718
|
-
for i,
|
|
728
|
+
for i, _ in enumerate(offset):
|
|
719
729
|
check_non_negative_int32(offset[i], "offset{0}".format(i))
|
|
720
730
|
|
|
721
731
|
check_fill_value(fill_value)
|
|
@@ -17,6 +17,7 @@ from __future__ import absolute_import
|
|
|
17
17
|
|
|
18
18
|
__all__ = ['MapParameter']
|
|
19
19
|
|
|
20
|
+
import os
|
|
20
21
|
import sys
|
|
21
22
|
from copy import copy
|
|
22
23
|
import numbers
|
|
@@ -46,7 +47,7 @@ class MapParameter(Parameter):
|
|
|
46
47
|
default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: 'normal'.
|
|
47
48
|
permit_filter_value (numbers.Number): The permit filter value number. Default: 1.
|
|
48
49
|
evict_filter_value (numbers.Number): The evict filter value number. Default: MAX_SIZE.
|
|
49
|
-
name (str): Name of the map parameter. Default: None
|
|
50
|
+
name (str): Name of the map parameter. Default: ``None``.
|
|
50
51
|
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
|
51
52
|
|
|
52
53
|
|
|
@@ -256,15 +257,28 @@ class MapParameter(Parameter):
|
|
|
256
257
|
|
|
257
258
|
Args:
|
|
258
259
|
incremental (bool): False for full export, otherwise for incremental export. Default: False.
|
|
259
|
-
When exporting data incrementally, the value_array does not contain
|
|
260
|
-
key_array and the length of the
|
|
261
|
-
of the status_array are consistent.
|
|
260
|
+
When exporting data incrementally, the value_array does not contain unchanged data.The length
|
|
261
|
+
of the key_array and the length of the status_array are consistent.
|
|
262
262
|
|
|
263
263
|
Returns:
|
|
264
264
|
Tuple(key_array, value_array, status_array), The exported data as a tuple.
|
|
265
265
|
"""
|
|
266
266
|
return self._map_tensor.export_data(incremental)
|
|
267
267
|
|
|
268
|
+
def export_bytes(self, incremental=False):
    """
    Export the content of this map parameter as raw bytes.

    Args:
        incremental (bool): False requests a full export; any other value requests an incremental export.
            Default: False. In an incremental export the value part does not contain unchanged data, while
            the key part and the status part keep the same length as each other.

    Returns:
        Tuple(bytes, bytes, bytes), The exported bytes as a tuple.
    """
    # Pure delegation: the underlying map tensor performs the serialization.
    exporter = self._map_tensor.export_bytes
    return exporter(incremental)
|
|
281
|
+
|
|
268
282
|
def import_data(self, data):
|
|
269
283
|
"""
|
|
270
284
|
Import this map parameter from exported data.
|
|
@@ -273,3 +287,23 @@ class MapParameter(Parameter):
|
|
|
273
287
|
data (Tuple): The data tuple with key_array, value_array and status_array.
|
|
274
288
|
"""
|
|
275
289
|
self._map_tensor.import_data(data)
|
|
290
|
+
|
|
291
|
+
def export_slice_data(self, incremental=False):
    """
    Export one slice of data from this map parameter.

    When the MapParameter occupies a large amount of memory, only one slice of it is exported per call
    (the default slice size is 1GB).

    If the environment variable ``MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE`` is present (with any value,
    including an empty string), the call is forwarded to the map tensor's
    ``export_persistent_slice_data`` together with this parameter's ``key``. Otherwise it is forwarded to
    the map tensor's ``export_slice_data``.

    Args:
        incremental (bool): False for full export, otherwise for incremental export. Default: False.
            When exporting data incrementally, the value_array does not contain unchanged data. The length
            of the key_array and the length of the status_array are consistent.

    Returns:
        Tuple(key_array, value_array, status_array, last_slice), The exported data as a tuple, where
        last_slice is a bool indicating whether the export has finished.
    """
    persistent_mode = os.environ.get("MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE") is not None
    if persistent_mode:
        return self._map_tensor.export_persistent_slice_data(self.key, incremental)
    return self._map_tensor.export_slice_data(incremental)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,10 +11,20 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
#
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Optimizer.
|
|
15
17
|
|
|
16
|
-
|
|
18
|
+
Provide common optimizers for training, such as SGD, Adam and AdamW.
|
|
19
|
+
An optimizer uses the computed gradients to update the network parameters.
|
|
17
20
|
"""
|
|
18
21
|
from __future__ import absolute_import
|
|
19
22
|
|
|
20
|
-
from mindspore.
|
|
23
|
+
from mindspore.experimental.optim.optimizer import Optimizer
|
|
24
|
+
from mindspore.experimental.optim.adamw import AdamW
|
|
25
|
+
from mindspore.experimental.optim.sgd import SGD
|
|
26
|
+
from mindspore.experimental.optim.adam import Adam
|
|
27
|
+
from mindspore.experimental.optim import lr_scheduler
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
__all__ = ['Optimizer', 'AdamW', 'SGD', 'Adam']
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""adam"""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
|
|
18
|
+
from mindspore.ops import functional as F, composite as C, operations as P
|
|
19
|
+
from mindspore.common.parameter import Parameter
|
|
20
|
+
from mindspore.common.tensor import Tensor
|
|
21
|
+
import mindspore.common.dtype as mstype
|
|
22
|
+
from mindspore.experimental.optim.optimizer import Optimizer
|
|
23
|
+
|
|
24
|
+
# Type-dispatched function graph that Adam.construct applies to every parameter through hyper_map.
# Two overloads are registered below: the plain Adam step (11 arguments, betas/eps as Float) and the
# AMSGrad step (9 arguments, betas/eps already baked into the primitive).
_adam_opt = C.MultitypeFuncGraph("adam_opt")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@_adam_opt.register("Function", "Tensor", "Tensor", "Float", "Float", "Float", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor")
def _run_adam_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply one plain (non-AMSGrad) Adam step to a single weight parameter.

    Args:
        opt: The Adam update primitive to invoke.
        beta1_power, beta2_power (Tensor): beta1 and beta2 raised to the current step.
        beta1, beta2, eps (float): Adam hyper-parameters.
        lr (Tensor): Learning rate.
        gradient (Tensor): Gradient of `param`.
        param, moment1, moment2 (Tensor): The weight and its first/second moment buffers.

    Returns:
        True, wrapped in ``F.depend`` so that the flag depends on the update call.
    """
    adam_update = opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(True, adam_update)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_adam_with_amsgrad_opt(opt, beta1_power, beta2_power, lr, gradient, param, moment1, moment2, vhat):
    """Apply one AMSGrad-variant Adam step to a single weight parameter.

    Args:
        opt: The AMSGrad update primitive to invoke (betas and eps are configured on the primitive).
        beta1_power, beta2_power (Tensor): beta1 and beta2 raised to the current step.
        lr (Tensor): Learning rate.
        gradient (Tensor): Gradient of `param`.
        param, moment1, moment2, vhat (Tensor): The weight, its moment buffers and the running max buffer.

    Returns:
        True, wrapped in ``F.depend`` so that the flag depends on the update call.
    """
    amsgrad_update = opt(param, moment1, moment2, vhat, beta1_power, beta2_power, lr, gradient)
    return F.depend(True, amsgrad_update)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Adam(Optimizer):
    r"""
    Implements Adam algorithm.

    The updating formulas are as follows:

    .. math::
        \begin{aligned}
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
                \:\textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\bf{return} \: \theta_t \\[-1.ex]
        \end{aligned}

    .. warning::
        This is an experimental optimizer API that is subject to change.
        This module must be used with lr scheduler module in `LRScheduler Class
        <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.nn.html#learningrateschedule-class>`_ .

    Args:
        params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
            parameter groups.
        lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
        betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
            Default: ``(0.9, 0.999)``.
        eps (float, optional): term added to the denominator to improve
            numerical stability. Default: ``1e-8``.
        weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
        amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.

    Keyword Args:
        maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
            Default: ``False``.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`.

    Raises:
        ValueError: If the `lr` is not int, float or Tensor.
        ValueError: If the `lr` is less than 0.
        ValueError: If the `eps` is less than 0.0.
        ValueError: If the `betas` not in the range of 0-1.
        ValueError: If the `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import nn
        >>> from mindspore.experimental import optim
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> optimizer = optim.Adam(net.trainable_params(), lr=0.1)
        >>> def forward_fn(data, label):
        ...     logits = net(data)
        ...     loss = loss_fn(logits, label)
        ...     return loss, logits
        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        >>> def train_step(data, label):
        ...     (loss, _), grads = grad_fn(data, label)
        ...     optimizer(grads)
        ...     return loss
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, maximize=False):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize)
        super(Adam, self).__init__(params, defaults)

        # Per-parameter state buffers, all zero-initialized clones of the optimized parameters.
        self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
        self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
        # Only read by the AMSGrad branch of construct.
        self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
        # Global step counter shared by all groups; incremented once per construct call.
        self.state_step = Parameter(Tensor(0, mstype.int32), "state_step")
        self.increase_tensor = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()
        self.op_add = P.AddN()
        self.op_mul = P.Mul()
        self.op_pow = P.Pow()
        # Element-wise negation used to flip gradients when `maximize` is enabled.
        self.op_neg = P.Neg()
        self.adam_opt = P.Adam(False, False)
        self.op_cast = P.Cast()

    def construct(self, gradients):
        """Run one Adam (or AMSGrad) update over every parameter group.

        Args:
            gradients (tuple[Tensor]): Gradients of `self.parameters`, in the same order.

        Returns:
            bool, always True. Parameters and moment buffers are updated by the optimizer primitives.
        """
        self.assignadd(self.state_step, self.increase_tensor)
        for group_id, group in enumerate(self.param_groups):
            start_id = self.group_start_id[group_id]
            end_id = self.group_start_id[group_id + 1]

            lr = group.get("lr")
            weight_decay = group.get("weight_decay")
            beta1, beta2 = group.get("betas")
            maximize = group.get("maximize")
            eps = group.get("eps")

            # beta^t for the current step t, consumed by the primitives for bias correction.
            beta1_power = self.op_pow(beta1, self.state_step)
            beta2_power = self.op_pow(beta2, self.state_step)
            adam_with_amsgrad_opt = P.ApplyAdamWithAmsgrad(beta1, beta2, eps, False)
            params = self.parameters[start_id: end_id]
            grads = gradients[start_id: end_id]
            # `grads` is a tuple of tensors, which has no unary minus; negate each gradient instead.
            grads = grads if not maximize else self.hyper_map(self.op_neg, grads)
            grads = self._decay_weight(weight_decay, params, grads)
            # The _adam_opt overloads take `lr` as a Tensor, so plain Python numbers (int or float,
            # both documented as valid) are converted first.
            if isinstance(lr, (int, float)):
                lr = self.op_cast(lr, mstype.float32)
            if group.get("amsgrad"):
                self.hyper_map(F.partial(_adam_opt, adam_with_amsgrad_opt, beta1_power, beta2_power, lr),
                               grads, params,
                               self.exp_avg[start_id: end_id], self.exp_avg_sq[start_id: end_id],
                               self.max_exp_avg_sq[start_id: end_id])
            else:
                self.hyper_map(F.partial(_adam_opt, self.adam_opt, beta1_power, beta2_power, beta1, beta2, eps, lr),
                               grads, params,
                               self.exp_avg[start_id: end_id], self.exp_avg_sq[start_id: end_id])
        return True
|