mindspore-2.3.0rc1-cp38-none-any.whl → mindspore-2.3.0rc2-cp38-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Note: this release of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +20 -0
- mindspore/_extends/parse/parser.py +1 -1
- mindspore/_extends/parse/standard_method.py +6 -5
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -5
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -2
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +1 -0
- mindspore/common/api.py +56 -4
- mindspore/common/dtype.py +5 -3
- mindspore/common/dump.py +2 -2
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +17 -6
- mindspore/common/parameter.py +7 -2
- mindspore/common/recompute.py +247 -0
- mindspore/common/sparse_tensor.py +2 -2
- mindspore/common/symbol.py +1 -1
- mindspore/common/tensor.py +74 -36
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/management.py +30 -30
- mindspore/context.py +28 -15
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +51 -51
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +3 -3
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +3 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +3 -3
- mindspore/dataset/engine/datasets_vision.py +68 -68
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +26 -26
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/transforms.py +92 -92
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/experimental/optim/adadelta.py +2 -2
- mindspore/experimental/optim/adagrad.py +2 -2
- mindspore/experimental/optim/adam.py +2 -2
- mindspore/experimental/optim/adamax.py +2 -2
- mindspore/experimental/optim/adamw.py +2 -2
- mindspore/experimental/optim/asgd.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +24 -20
- mindspore/experimental/optim/nadam.py +2 -2
- mindspore/experimental/optim/optimizer.py +1 -1
- mindspore/experimental/optim/radam.py +2 -2
- mindspore/experimental/optim/rmsprop.py +2 -2
- mindspore/experimental/optim/rprop.py +2 -2
- mindspore/experimental/optim/sgd.py +2 -2
- mindspore/hal/stream.py +2 -0
- mindspore/include/mindapi/base/types.h +5 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +2 -2
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +4 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +51 -47
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/nn/extend/layer/__init__.py +27 -0
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/basic.py +109 -1
- mindspore/nn/layer/container.py +2 -2
- mindspore/nn/layer/conv.py +6 -6
- mindspore/nn/layer/embedding.py +1 -1
- mindspore/nn/layer/normalization.py +21 -43
- mindspore/nn/layer/padding.py +4 -0
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +7 -7
- mindspore/nn/optim/adamax.py +2 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +1 -1
- mindspore/nn/optim/lamb.py +3 -3
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +2 -2
- mindspore/nn/optim/momentum.py +2 -2
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +2 -2
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/sgd.py +2 -2
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +9 -9
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
- mindspore/ops/_vmap/vmap_math_ops.py +27 -8
- mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
- mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
- mindspore/ops/auto_generate/gen_extend_func.py +274 -0
- mindspore/ops/auto_generate/gen_ops_def.py +889 -22
- mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
- mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
- mindspore/ops/extend/__init__.py +9 -1
- mindspore/ops/extend/array_func.py +134 -27
- mindspore/ops/extend/math_func.py +3 -3
- mindspore/ops/extend/nn_func.py +363 -2
- mindspore/ops/function/__init__.py +19 -2
- mindspore/ops/function/array_func.py +463 -439
- mindspore/ops/function/clip_func.py +7 -18
- mindspore/ops/function/grad/grad_func.py +5 -5
- mindspore/ops/function/linalg_func.py +4 -4
- mindspore/ops/function/math_func.py +260 -243
- mindspore/ops/function/nn_func.py +825 -62
- mindspore/ops/function/random_func.py +73 -4
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +1 -1
- mindspore/ops/functional.py +2 -2
- mindspore/ops/op_info_register.py +1 -31
- mindspore/ops/operations/__init__.py +2 -3
- mindspore/ops/operations/_grad_ops.py +2 -107
- mindspore/ops/operations/_inner_ops.py +5 -5
- mindspore/ops/operations/_sequence_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +11 -233
- mindspore/ops/operations/comm_ops.py +32 -32
- mindspore/ops/operations/custom_ops.py +7 -89
- mindspore/ops/operations/manually_defined/ops_def.py +329 -4
- mindspore/ops/operations/math_ops.py +13 -163
- mindspore/ops/operations/nn_ops.py +9 -316
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +2 -2
- mindspore/ops_generate/arg_dtype_cast.py +12 -3
- mindspore/ops_generate/arg_handler.py +24 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
- mindspore/ops_generate/gen_pyboost_func.py +13 -6
- mindspore/ops_generate/pyboost_utils.py +2 -17
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +106 -1
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_utils.py +16 -0
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/checkpoint_transform.py +249 -77
- mindspore/parallel/cluster/process_entity/_api.py +1 -1
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +1 -1
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
- mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
- mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
- mindspore/profiler/parser/ascend_op_generator.py +26 -9
- mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
- mindspore/profiler/parser/profiler_info.py +11 -1
- mindspore/profiler/profiling.py +13 -5
- mindspore/rewrite/api/node.py +12 -12
- mindspore/rewrite/api/symbol_tree.py +11 -11
- mindspore/run_check/_check_version.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +2 -2
- mindspore/train/amp.py +4 -4
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +2 -2
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +2 -2
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/dataset_helper.py +8 -3
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/mind_ir_pb2.py +22 -17
- mindspore/train/model.py +15 -15
- mindspore/train/serialization.py +18 -18
- mindspore/train/summary/summary_record.py +7 -7
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +309 -262
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
- /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/mint/nn/functional.py
ADDED

@@ -0,0 +1,424 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""mint nn functional."""
+from __future__ import absolute_import
+from mindspore.ops.extend import max_pool2d
+from mindspore.ops.functional import (
+    conv_transpose2d,
+    grid_sample
+)
+# 1
+
+# 2
+
+# 3
+
+# 4
+
+# 5
+from mindspore.ops.functional import pad_ext as pad
+# 6
+
+# 7
+
+# 8
+from mindspore.ops.functional import layer_norm
+# 9
+from mindspore.ops.function.nn_func import interpolate_ext as interpolate
+# 10
+
+# 11
+from mindspore.ops.functional import relu
+# 12
+
+# 13
+
+# 14
+from mindspore.ops.function.nn_func import dropout_ext as dropout
+# 15
+
+# 16
+
+# 17
+
+# 18
+
+# 19
+
+# 20
+
+# 21
+
+# 22
+
+# 23
+
+# 24
+
+# 25
+
+# 26
+
+# 27
+
+# 28
+
+# 29
+
+# 30
+
+# 31
+
+# 32
+
+# 33
+
+# 34
+
+# 35
+
+# 36
+from mindspore.ops.functional import gelu
+# 37
+
+# 38
+
+# 39
+from mindspore.ops.functional import group_norm
+# 40
+
+# 41
+
+# 42
+
+# 43
+
+# 44
+
+# 45
+
+# 46
+from mindspore.ops.functional import silu
+# 47
+
+# 48
+
+# 49
+from mindspore.ops.functional import sigmoid
+# 50
+
+# 51
+
+# 52
+from mindspore.ops.functional import embedding
+# 53
+
+# 54
+
+# 55
+
+# 56
+
+# 57
+
+# 58
+
+# 59
+
+# 60
+
+# 61
+
+# 62
+
+# 63
+
+# 64
+
+# 65
+
+# 66
+
+# 67
+
+# 68
+
+# 69
+
+# 70
+
+# 71
+
+# 72
+
+# 73
+
+# 74
+
+# 75
+
+# 76
+
+# 77
+
+# 78
+
+# 79
+
+# 80
+
+# 81
+
+# 82
+
+# 83
+
+# 84
+
+# 85
+
+# 86
+
+# 87
+
+# 88
+
+# 89
+
+# 90
+from mindspore.ops.function.nn_func import avg_pool2d_ext as avg_pool2d
+# 91
+
+# 92
+from mindspore.ops.extend import leaky_relu_ext as leaky_relu
+# 93
+from mindspore.ops.function.nn_func import softplus_ext as softplus
+# 94
+from mindspore.ops.function.math_func import tanh
+# 95
+
+# 96
+
+# 97
+
+# 98
+
+# 99
+
+# 100
+
+__all__ = [
+    'conv_transpose2d',
+    'max_pool2d',
+    # 1
+
+    # 2
+
+    # 3
+
+    # 4
+
+    # 5
+    'pad',
+    # 6
+
+    # 7
+
+    # 8
+    'layer_norm',
+    # 9
+    'interpolate',
+    # 10
+
+    # 11
+    'relu',
+    # 12
+
+    # 13
+
+    # 14
+    'dropout',
+    # 15
+
+    # 16
+
+    # 17
+
+    # 18
+
+    # 19
+
+    # 20
+
+    # 21
+
+    # 22
+
+    # 23
+
+    # 24
+
+    # 25
+
+    # 26
+
+    # 27
+
+    # 28
+
+    # 29
+
+    # 30
+
+    # 31
+
+    # 32
+
+    # 33
+
+    # 34
+
+    # 35
+
+    # 36
+    'gelu',
+    # 37
+
+    # 38
+
+    # 39
+    'group_norm',
+    # 40
+
+    # 41
+
+    # 42
+
+    # 43
+
+    # 44
+
+    # 45
+
+    # 46
+    'silu',
+    # 47
+
+    # 48
+
+    # 49
+    'sigmoid',
+    # 50
+
+    # 51
+
+    # 52
+    'embedding',
+    # 53
+
+    # 54
+
+    # 55
+
+    # 56
+
+    # 57
+
+    # 58
+
+    # 59
+
+    # 60
+
+    # 61
+
+    # 62
+
+    # 63
+
+    # 64
+
+    # 65
+
+    # 66
+
+    # 67
+
+    # 68
+
+    # 69
+
+    # 70
+
+    # 71
+
+    # 72
+
+    # 73
+
+    # 74
+
+    # 75
+
+    # 76
+
+    # 77
+
+    # 78
+
+    # 79
+
+    # 80
+
+    # 81
+
+    # 82
+
+    # 83
+
+    # 84
+
+    # 85
+
+    # 86
+
+    # 87
+
+    # 88
+
+    # 89
+
+    # 90
+    'avg_pool2d',
+    # 91
+    'grid_sample',
+    # 92
+    'leaky_relu',
+    # 93
+    'softplus',
+    # 94
+    'tanh',
+    # 95
+
+    # 96
+
+    # 97
+
+    # 98
+
+    # 99
+
+    # 100
+]
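The hunk above re-exports existing ops under a torch-style `mindspore.mint.nn.functional` namespace; the `# 1`–`# 100` comments are reserved slots for entries not yet wired up. A minimal sketch of how these re-exports can be called, assuming the rc2 wheel is installed and that the aliases behave like their `ops.functional` counterparts (shapes and values are illustrative only):

# Sketch (not part of the diff): exercising a few of the new re-exports.
import numpy as np
from mindspore import Tensor
from mindspore.mint.nn import functional as F

x = Tensor(np.random.randn(1, 3, 8, 8).astype(np.float32))
y = F.relu(x)                        # re-export of ops.functional.relu
y = F.avg_pool2d(y, kernel_size=2)   # avg_pool2d_ext under its torch-style name
y = F.pad(y, (1, 1, 1, 1))           # pad_ext re-exported as `pad`
print(y.shape)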
mindspore/mint/optim/__init__.py
ADDED

@@ -0,0 +1,24 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Optimizer.
+
+Provide common optimizers for training, such as AdamW.
+The optimizer is used to calculate and update the gradients.
+"""
+from __future__ import absolute_import
+from mindspore.mint.optim.adamw import AdamW
+
+__all__ = ['AdamW']
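The new `mindspore.mint.optim` package mirrors the `torch.optim` import style. A one-line construction sketch (any `Cell`'s `trainable_params()` works; `nn.Dense` is just an example):

# Sketch: constructing the new mint AdamW against an arbitrary Cell.
import mindspore.nn as nn
from mindspore.mint import optim

net = nn.Dense(4, 3)
opt = optim.AdamW(net.trainable_params(), lr=1e-3, weight_decay=1e-2)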
mindspore/mint/optim/adamw.py
ADDED

@@ -0,0 +1,186 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""adamw"""
+from __future__ import absolute_import
+
+from mindspore.ops import functional as F, composite as C, operations as P
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.ops import auto_generate as gen
+from mindspore.experimental.optim.optimizer import Optimizer
+from mindspore import _checkparam as validator
+
+_optim_adamw_opt = C.MultitypeFuncGraph("optim_adamw_opt")
+hyper_map = C.HyperMap()
+
+
+@_optim_adamw_opt.register("Function", "Float", "Float", "Float", "Float", "Float", "Tensor", "Bool", "Bool", "Tensor",
+                           "Tensor", "Tensor", "Tensor", "Tensor")
+def _run_optim_adamw_opt(opt, beta1, beta2, lr, eps, weight_decay, step, amsgrad, maximize, parameters, grads, exp_avg,
+                         exp_avg_sq, max_exp_avg_sq):
+    """Apply adamw optimizer to the weight parameter."""
+    success = True
+    opt(parameters, exp_avg, exp_avg_sq, max_exp_avg_sq, P.Cast()(grads, F.dtype(parameters)), step, lr, beta1, beta2,
+        weight_decay, eps, amsgrad, maximize)
+    return success
+
+
+def _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, prim_name):
+    """Check the type of inputs."""
+    validator.check_value_type('betas', betas, [tuple], prim_name)
+    validator.check("betas size", len(betas), "", [2], validator.IN, prim_name)
+    validator.check_value_type("betas[0]", betas[0], [float], prim_name)
+    validator.check_value_type("betas[1]", betas[1], [float], prim_name)
+    validator.check_value_type("eps", eps, [float], prim_name)
+    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
+    validator.check_value_type("lr", lr, [float], prim_name)
+    validator.check_value_type("amsgrad", amsgrad, [bool], prim_name)
+    validator.check_value_type("maximize", maximize, [bool], prim_name)
+
+
+class AdamW(Optimizer):
+    r"""
+    Implements Adam Weight Decay algorithm.
+
+    .. math::
+        \begin{aligned}
+            &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
+                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
+                \: \epsilon \text{ (epsilon)} \\
+            &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
+                \: \textit{maximize} \\
+            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
+                \text{ (second moment)}, \: \widehat{v_0}^{max} \leftarrow 0 \\[-1.ex]
+            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
+            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
+            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{5mm}\textbf{else} \\
+            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
+            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
+            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
+            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
+            &\hspace{5mm}\textbf{if} \: amsgrad \\
+            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
+                \widehat{v_t}) \\
+            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
+            &\hspace{5mm}\textbf{else} \\
+            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
+            &\bf{return} \: \theta_t \\[-1.ex]
+        \end{aligned}
+
+    .. warning::
+        This is an experimental optimizer API that is subject to change.
+        This module must be used with the lr scheduler module in `LRScheduler Class
+        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
+
+    Args:
+        params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
+            parameter groups.
+        lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
+        betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
+            Default: ``(0.9, 0.999)``.
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. Default: ``1e-8``.
+        weight_decay (float, optional): weight decay (L2 penalty). Default: ``1e-2``.
+        amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
+
+    Keyword Args:
+        maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
+            Default: ``False``.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`.
+
+    Raises:
+        ValueError: If the learning rate is not int, float or Tensor.
+        ValueError: If the learning rate is less than 0.
+        ValueError: If the `eps` is less than 0.0.
+        ValueError: If the `betas` are not in the range of [0, 1).
+        ValueError: If the `weight_decay` is less than 0.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import nn
+        >>> from mindspore.mint import optim
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+        >>> optimizer = optim.AdamW(net.trainable_params(), lr=0.1)
+        >>> def forward_fn(data, label):
+        ...     logits = net(data)
+        ...     loss = loss_fn(logits, label)
+        ...     return loss, logits
+        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
+        >>> def train_step(data, label):
+        ...     (loss, _), grads = grad_fn(data, label)
+        ...     optimizer(grads)
+        ...     return loss
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=1e-2, amsgrad=False, *, maximize=False):
+        _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, self.cls_name)
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if eps < 0.0:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad,
+                        maximize=maximize)
+        super(AdamW, self).__init__(params, defaults)
+
+        self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
+        self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
+        self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
+        self.state_step = Parameter(Tensor([0], mstype.float32), "state_step")
+        self.increase_tensor = Tensor(1, mstype.float32)
+        self.assignadd = P.AssignAdd()
+        self.op_cast = P.Cast()
+        self.adamw_opt = gen.AdamWeightDecayExt()
+
+    def construct(self, gradients):
+        self.assignadd(self.state_step, self.increase_tensor)
+        for group_id, group in enumerate(self.param_groups):
+            beta1, beta2 = group['betas']
+            maximize = group.get("maximize")
+            start_id = self.group_start_id[group_id]
+            end_id = self.group_start_id[group_id + 1]
+            lr = self.lrs[group_id]
+            if isinstance(group.get("lr"), float):
+                lr = self.op_cast(group.get("lr"), mstype.float32)
+            grads = tuple([grad if not maximize else F.neg(grad) for grad in gradients[start_id: end_id]])
+
+            self.hyper_map(F.partial(_optim_adamw_opt, self.adamw_opt, beta1, beta2, float(lr),
+                                     group.get("eps"), group.get("weight_decay"), self.state_step,
+                                     group.get("amsgrad"), maximize),
+                           self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                           self.exp_avg_sq[start_id: end_id], self.max_exp_avg_sq[start_id: end_id])
+        return True
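The fused `AdamWeightDecayExt` kernel above hides the arithmetic; as a reading aid for the docstring math, here is the same update rule (without `amsgrad`/`maximize`) as one plain-Python scalar step. This is an illustration only, not MindSpore code; note how the weight decay is decoupled, applied to the parameter before the moment update rather than folded into the gradient:

# Illustration: one scalar AdamW step, mirroring the docstring's update rule.
import math

def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-2):
    theta = theta - lr * weight_decay * theta       # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g                 # first-moment EMA
    v = beta2 * v + (1 - beta2) * g * g             # second-moment EMA
    m_hat = m / (1 - beta1 ** t)                    # bias corrections
    v_hat = v / (1 - beta2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):                               # three steps on f(x) = x^2
    g = 2 * theta
    theta, m, v = adamw_step(theta, g, m, v, t)
print(theta)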
mindspore/multiprocessing/__init__.py
CHANGED

@@ -16,6 +16,7 @@
 mindspore.multiprocessing is a wrapper around the native `multiprocessing` module.
 Some methods are overrode to support fork-based multiprocess.
 """
+import types
 import signal
 import multiprocessing as mp
 from multiprocessing import *

@@ -64,5 +65,8 @@ class Pool(mp.pool.Pool):  # pylint: disable=function-redefined, abstract-method
     """
     def Process(self, *args, **kwds):
         if self._ctx.get_start_method() == "fork":
+            # Process() becomes a staticmethod function of Pool with first argument 'ctx' in python 3.8.0 and later
+            if isinstance(super().Process, types.FunctionType):
+                args = args[1:]
             return _MsProcess(*args, **kwds)
         return super().Process(*args, **kwds)
mindspore/nn/__init__.py
CHANGED

@@ -21,6 +21,7 @@ from __future__ import absolute_import
 
 from mindspore.nn import layer, loss, optim, wrap, grad, metrics, probability, sparse, dynamic_lr, reinforcement
 from mindspore.nn.learning_rate_schedule import *
+from mindspore.nn.generator import *
 from mindspore.nn.dynamic_lr import *
 from mindspore.nn.cell import Cell, GraphCell
 from mindspore.nn.layer import *

@@ -31,6 +32,7 @@ from mindspore.nn.wrap import *
 from mindspore.nn.grad import Jvp, Vjp
 from mindspore.nn.sparse import *
 from mindspore.nn.reinforcement import *
+from mindspore.nn import extend
 
 __all__ = ["Cell", "GraphCell"]
 __all__.extend(layer.__all__)

@@ -43,5 +45,6 @@ __all__.extend(sparse.__all__)
 __all__.extend(learning_rate_schedule.__all__)
 __all__.extend(dynamic_lr.__all__)
 __all__.extend(reinforcement.__all__)
+__all__.extend(generator.__all__)
 
 __all__.sort()