mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0rc1-cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/transformer.py
CHANGED

@@ -26,12 +26,12 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
     HeUniform, Uniform, _calculate_fan_in_and_fan_out
-from mindspore.ops.function.nn_func import multi_head_attention_forward
 from mindspore.nn.cell import Cell
 from .basic import Dense, Dropout
 from .activation import ReLU, GELU
 from .normalization import LayerNorm
 from .container import CellList
+
 __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
            'TransformerEncoder', 'TransformerDecoder', 'Transformer']

@@ -54,16 +54,16 @@ class MultiheadAttention(Cell):
         embed_dim (int): Total dimension of MultiheadAttention.
         num_heads (int): Number of attention heads. Note that `embed_dim` will be split
             across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`).
-        dropout (float): Dropout probability of `attn_output_weights`. Default: ``0.0``.
-        has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
-        add_bias_kv (bool): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
-        add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at axis=1.
+        dropout (float, optional): Dropout probability of `attn_output_weights`. Default: ``0.0``.
+        has_bias (bool, optional): Whether adds bias to input / output projection layers. Default: ``True``.
+        add_bias_kv (bool, optional): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
+        add_zero_attn (bool, optional): Whether adds a new batch of zeros to the key and value sequences at axis=1.
             Default: ``False``.
-        kdim (int): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
-        vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
-        batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
+        kdim (int, optional): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
+        vdim (int, optional): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
+        batch_first (bool, optional): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
             else :math:`(seq, batch, feature)` . Default: ``False``.
-        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+        dtype (:class:`mindspore.dtype`, optional): Data type of Parameter. Default: ``mstype.float32`` .

     Inputs:
         - **query** (Tensor) - The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
@@ -85,7 +85,7 @@ class MultiheadAttention(Cell):
           For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
           the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
           Supported float types: float16, float32, float64. Default: ``None``.
-        - **need_weights** (bool) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
+        - **need_weights** (bool, optional) - Whether returns `attn_output_weights` in addition to `attn_outputs`.
           Default: ``True``.
         - **attn_mask** (Tensor, optional) - If specified, a 2D or 3D mask preventing attention to certain positions.
           Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
@@ -94,7 +94,8 @@ class MultiheadAttention(Cell):
          in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
          to attend. For a float mask, the mask values will be added to the attention weight.
          Supported float types: float16, float32, float64. Default: ``None``.
-        - **average_attn_weights** (bool) - If true, indicates that the returned `attn_weights` should be averaged
+        - **average_attn_weights** (bool, optional) - If true, indicates that
+          the returned `attn_weights` should be averaged
          across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
          has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
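
The docstring hunks above define the public surface of `mindspore.nn.MultiheadAttention`. A minimal usage sketch, assuming the constructor arguments and call signature documented above (the sizes and shapes below are illustrative, not taken from the diff):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

# embed_dim is split across num_heads, so each head sees
# embed_dim // num_heads = 32 features (per the Args above).
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)

# batch_first=True selects the (batch, seq, feature) layout.
query = Tensor(np.random.randn(2, 10, 128), ms.float32)
key = Tensor(np.random.randn(2, 16, 128), ms.float32)
value = Tensor(np.random.randn(2, 16, 128), ms.float32)

# need_weights defaults to True, so the cell also returns attention
# weights, averaged over heads when average_attn_weights=True.
attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # (2, 10, 128)
print(attn_weights.shape)  # (2, 10, 16)
```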
@@ -212,7 +213,7 @@ class MultiheadAttention(Cell):
             query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)]

         if not self._qkv_same_embed_dim:
-            attn_output, attn_output_weights = multi_head_attention_forward(
+            attn_output, attn_output_weights = ops.function.nn_func.multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
                 self.in_proj_weight, self.in_proj_bias,
                 self.bias_k, self.bias_v, self.add_zero_attn,
@@ -224,7 +225,7 @@ class MultiheadAttention(Cell):
                 v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
                 k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
         else:
-            attn_output, attn_output_weights = multi_head_attention_forward(
+            attn_output, attn_output_weights = ops.function.nn_func.multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
                 self.in_proj_weight, self.in_proj_bias,
                 self.bias_k, self.bias_v, self.add_zero_attn,
@@ -328,7 +329,7 @@ class TransformerEncoderLayer(Cell):
         self.activation1 = activation

         if not isinstance(activation, str) and not isinstance(activation, Cell) \
-
+                and not callable(activation):
             raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                              f" but get {activation}.")
         if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
@@ -360,15 +361,23 @@ class TransformerEncoderLayer(Cell):
             raise AssertionError(
                 "only bool and floating types of key_padding_mask are supported")

-        x = src
+        input_data = src
+
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
-            x = x + self._ff_block(self.norm2(x))
+            normed_input = self.norm1(input_data)
+            sa_block_result = self._sa_block(normed_input, src_mask, src_key_padding_mask)
+            input_data = input_data + sa_block_result
+            normed_updated_input = self.norm2(input_data)
+            ff_block_result = self._ff_block(normed_updated_input)
+            input_data = input_data + ff_block_result
         else:
-            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
-            x = self.norm2(x + self._ff_block(x))
+            sa_block_result = self._sa_block(input_data, src_mask, src_key_padding_mask)
+            normed_sa_result = self.norm1(input_data + sa_block_result)
+            input_data = normed_sa_result
+            ff_block_result = self._ff_block(input_data)
+            input_data = self.norm2(input_data + ff_block_result)

-        return x
+        return input_data

     def _sa_block(self, x, attn_mask, key_padding_mask):
         x = self.self_attn(x, x, x,
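
The rewrite above replaces one-line residual updates with named intermediates; the computation is unchanged. `norm_first=True` is the pre-norm ordering (normalize the input of each sub-block), `norm_first=False` the post-norm ordering (normalize after each residual add). A self-contained NumPy sketch of the two orderings, with toy stand-ins for the attention and feed-forward blocks (illustrative only, not the MindSpore implementation):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def encoder_layer(x, sa_block, ff_block, norm_first):
    if norm_first:
        # Pre-norm: the residual stream itself stays un-normalized,
        # which is why the refactored code adds sa/ff results directly.
        x = x + sa_block(layer_norm(x))
        x = x + ff_block(layer_norm(x))
    else:
        # Post-norm: normalize after each residual addition.
        x = layer_norm(x + sa_block(x))
        x = layer_norm(x + ff_block(x))
    return x

# Toy sub-blocks standing in for self-attention and feed-forward.
rng = np.random.default_rng(0)
w = rng.standard_normal((8, 8)) * 0.1
x = rng.standard_normal((4, 8))
out = encoder_layer(x, lambda t: t @ w, lambda t: np.tanh(t @ w), norm_first=True)
print(out.shape)  # (4, 8)
```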
@@ -480,7 +489,7 @@ class TransformerDecoderLayer(Cell):
         self.activation1 = activation

         if not isinstance(activation, str) and not isinstance(activation, Cell) \
-
+                and not callable(activation):
             raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                              f" but get {activation}.")
         if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
@@ -507,17 +516,29 @@ class TransformerDecoderLayer(Cell):
     def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                   memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                   memory_key_padding_mask: Optional[Tensor] = None):
-        x = tgt
+        input_data = tgt
+
         if self.norm_first:
-            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
-            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
-            x = x + self._ff_block(self.norm3(x))
+            normed_input = self.norm1(input_data)
+            sa_block_result = self._sa_block(normed_input, tgt_mask, tgt_key_padding_mask)
+            input_data = input_data + sa_block_result
+            normed_updated_input_1 = self.norm2(input_data)
+            mha_block_result = self._mha_block(normed_updated_input_1, memory, memory_mask, memory_key_padding_mask)
+            input_data = input_data + mha_block_result
+            normed_updated_input_2 = self.norm3(input_data)
+            ff_block_result = self._ff_block(normed_updated_input_2)
+            input_data = input_data + ff_block_result
         else:
-            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
-            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
-            x = self.norm3(x + self._ff_block(x))
+            sa_block_result = self._sa_block(input_data, tgt_mask, tgt_key_padding_mask)
+            normed_sa_result = self.norm1(input_data + sa_block_result)
+            input_data = normed_sa_result
+            mha_block_result = self._mha_block(input_data, memory, memory_mask, memory_key_padding_mask)
+            normed_mha_result = self.norm2(input_data + mha_block_result)
+            input_data = normed_mha_result
+            ff_block_result = self._ff_block(input_data)
+            input_data = self.norm3(input_data + ff_block_result)

-        return x
+        return input_data

     def _sa_block(self, x, attn_mask, key_padding_mask):
         x = self.self_attn(x, x, x,
@@ -670,17 +691,19 @@ class TransformerDecoder(Cell):
     def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                   memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                   memory_key_padding_mask: Optional[Tensor] = None):
-        output = tgt
+        processed_output = tgt
         for mod in self.layers:
-            output = mod(output, memory, tgt_mask=tgt_mask,
-                         memory_mask=memory_mask,
-                         tgt_key_padding_mask=tgt_key_padding_mask,
-                         memory_key_padding_mask=memory_key_padding_mask)
+            layer_output = mod(processed_output, memory,
+                               tgt_mask=tgt_mask,
+                               memory_mask=memory_mask,
+                               tgt_key_padding_mask=tgt_key_padding_mask,
+                               memory_key_padding_mask=memory_key_padding_mask)
+            processed_output = layer_output

         if self.norm is not None:
-            output = self.norm(output)
+            processed_output = self.norm(processed_output)

-        return output
+        return processed_output


 class Transformer(Cell):
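
`TransformerDecoder.construct` above simply threads `tgt` through each stacked layer and applies the optional final norm. A hedged usage sketch, assuming the conventional `nn.TransformerDecoder(decoder_layer, num_layers)` constructor and the default `(seq, batch, feature)` layout (d_model/nhead values are illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

# Default layout is (seq, batch, feature): 20 target steps,
# 10 memory steps, batch of 32.
tgt = Tensor(np.random.rand(20, 32, 512), ms.float32)
memory = Tensor(np.random.rand(10, 32, 512), ms.float32)
out = decoder(tgt, memory)
print(out.shape)  # (20, 32, 512)
```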
mindspore/nn/learning_rate_schedule.py
CHANGED

@@ -80,7 +80,8 @@ class ExponentialDecayLR(LearningRateSchedule):
         learning_rate (float): The initial value of learning rate.
         decay_rate (float): The decay rate.
         decay_steps (int): Number of steps to decay over.
-        is_stair (bool): If ``True``, learning rate is decayed once every `decay_steps` time. Default: ``False`` .
+        is_stair (bool, optional): If ``True``, learning rate is decayed once every `decay_steps` time.
+            Default: ``False`` .

     Inputs:
         - **global_step** (Tensor) - The current step number. :math:`current\_step` in the above formula.
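
A worked check of the `is_stair` behavior documented above, using the exponential-decay formula `learning_rate * decay_rate^(current_step / decay_steps)` with the exponent floored when `is_stair` is true (plain NumPy; values are illustrative):

```python
import numpy as np

def exponential_decay_lr(learning_rate, decay_rate, decay_steps, step, is_stair=False):
    p = step / decay_steps
    if is_stair:
        p = np.floor(p)  # decay only once every decay_steps steps
    return learning_rate * decay_rate ** p

for step in range(6):
    smooth = exponential_decay_lr(0.1, 0.9, 3, step)                 # decays every step
    stair = exponential_decay_lr(0.1, 0.9, 3, step, is_stair=True)   # drops at steps 3, 6, ...
    print(step, round(float(smooth), 5), round(float(stair), 5))
```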
@@ -223,7 +224,9 @@ class InverseDecayLR(LearningRateSchedule):
         learning_rate (float): The initial value of learning rate.
         decay_rate (float): The decay rate.
         decay_steps (int): Number of steps to decay over.
-        is_stair (bool): If true, learning rate decay once every `decay_steps` times.
+        is_stair (bool, optional): If true, learning rate decay once every `decay_steps` times.
+            If False, the learning rate
+            decays for every step. Default: ``False`` .

     Inputs:
         - **global_step** (Tensor) - The current step number.
@@ -454,8 +457,9 @@ class WarmUpLR(LearningRateSchedule):
         tmp\_step= \min(current\_step, warmup\_steps)

     Args:
-        learning_rate (float): The initial value of learning rate.
-        warmup_steps (int): The warm up steps of learning rate.
+        learning_rate (float): The initial value of learning rate. The value of `learning_rate` must be greater than 0.
+        warmup_steps (int): The warm up steps of learning rate. The value of `warmup_steps` must be greater than
+            or equal to 1.

     Inputs:
         - **global_step** (Tensor) - The current step number. Shape is :math:`()`.
mindspore/nn/loss/loss.py
CHANGED
@@ -24,8 +24,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
-from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
 from mindspore.ops import functional as F
 from mindspore import nn
 from mindspore.ops.primitive import constexpr, _primexpr
@@ -33,7 +31,6 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore import _checkparam as validator
 from mindspore import context
-from mindspore.ops.auto_generate import l1_loss_ext_op


 class LossBase(Cell):
@@ -130,7 +127,8 @@ class LossBase(Cell):
         Args:
             x (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
                 additional dimensions.
-            weights (Union[float, Tensor]):
+            weights (Union[float, Tensor], optional): Weights. When `weights` is a Tensor,
+                the rank is either 0, or the same rank as inputs,
                 and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
                 or the same as the corresponding inputs dimension). Default: ``1.0`` .
@@ -319,7 +317,7 @@ class L1LossExt(LossBase):
         self.reduction = reduction

     def construct(self, logits, labels):
-        return l1_loss_ext_op(logits, labels, self.reduction)
+        return ops.auto_generate.l1_loss_ext_op(logits, labels, self.reduction)


 class MSELoss(LossBase):
@@ -620,7 +618,8 @@ class MarginRankingLoss(LossBase):

 class SmoothL1Loss(LossBase):
     r"""
-    SmoothL1 loss function
+    SmoothL1 loss function. Compare the error value element-wise and
+    if the absolute error between the predicted value and the target value
     is less than the set threshold `beta`, the square term is used, otherwise the absolute error term is used.

     Given two input :math:`x,\ y`, the SmoothL1Loss can be described as follows:
@@ -628,11 +627,11 @@ class SmoothL1Loss(LossBase):
     .. math::
         L_{i} =
         \begin{cases}
-        \frac{0.5 (x_i - y_i)^{2}}{\beta}, & \text{if } |x_i - y_i| < {\beta} \\
-        |x_i - y_i| - 0.5 {\beta}, & \text{otherwise.}
+        \frac{0.5 (x_i - y_i)^{2}}{\text{beta}}, & \text{if } |x_i - y_i| < \text{beta} \\
+        |x_i - y_i| - 0.5 * {\text{beta}}, & \text{otherwise.}
         \end{cases}

-    Where :math:`{\beta}` represents the threshold `beta`.
+    Where :math:`{\text{beta}}` represents the threshold `beta`.

     If `reduction` is not `none`, then:
@@ -653,8 +652,11 @@ class SmoothL1Loss(LossBase):
         robust to outliers, and the loss function has better robustness.

     Args:
-        beta (
-            Default: ``1.0`` .
+        beta (number, optional): The loss function calculates the threshold of the transformation
+            between L1Loss and L2Loss. Default: ``1.0`` .
+
+            - Ascend: The value should be equal to or greater than zero.
+            - CPU/GPU: The value should be greater than zero.
         reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
             ``'sum'`` . Default: ``'none'`` .
@@ -663,22 +665,28 @@ class SmoothL1Loss(LossBase):
         - ``'sum'``: the output elements will be summed.

     Inputs:
-        - **logits** (Tensor) - Predictive value. Tensor of any dimension.
-
-
+        - **logits** (Tensor) - Predictive value. Tensor of any dimension. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32, float64.
+
+        - **labels** (Tensor) - Ground truth data.
+
+          - CPU/Ascend: has the same shape as the `logits`,
+            `logits` and `labels` comply with the implicit type conversion rules to make the data types consistent.
+          - GPU: has the same shape and dtype as the `logits`.

     Outputs:
         Tensor, if `reduction` is ``'none'``, then output is a tensor with the same shape as `logits`.
         Otherwise the shape of output tensor is :math:`()`.

     Raises:
-        TypeError: If `
-
-        TypeError: If `logits` or `labels` are not Tensor.
-        TypeError: If dtype of `logits` or `labels` is neither float16 not float32.
-        TypeError: If dtype of `logits` is not the same as `labels`.
-        ValueError: If `beta` is less than or equal to 0.
+        TypeError: If input `logits` or `labels` are not Tensor.
+        RuntimeError: If dtype of `logits` or `labels` is not one of float16, float32, float64, bfloat16.
         ValueError: If shape of `logits` is not the same as `labels`.
+        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
+        TypeError: If `beta` is not a float, int or bool.
+        RuntimeError: If `beta` is less than or equal to 0.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
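
A quick numerical check of the piecewise definition above, implemented directly in NumPy (independent of MindSpore): with `beta = 1.0`, an absolute error of 0.5 lands on the quadratic branch and an error of 2.0 on the linear branch.

```python
import numpy as np

def smooth_l1(x, y, beta=1.0):
    diff = np.abs(x - y)
    # Quadratic near zero, linear in the tails, per the .. math:: block above.
    return np.where(diff < beta,
                    0.5 * diff ** 2 / beta,
                    diff - 0.5 * beta)

logits = np.array([1.0, 2.0, 3.0])
labels = np.array([1.5, 2.0, 5.0])
print(smooth_l1(logits, labels))  # [0.125 0.    1.5  ]
```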
@@ -728,16 +736,19 @@ class SoftMarginLoss(LossBase):
         - ``'sum'``: the output elements will be summed.

     Inputs:
-        - **logits** (Tensor) - Predict data. Data type must be float16
-
+        - **logits** (Tensor) - Predict data. Data type must be float16, float32,
+          bfloat16 (Among them, the Atlas training series products do not support bfloat16).
+        - **labels** (Tensor) - Ground truth data, with the same shape as `logits`.
+          In GE mode, the data type should be the same as `logits`.

     Outputs:
-        Tensor or Scalar, if `reduction` is ``
+        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `logits`.
         Otherwise, a scalar value will be returned.

     Raises:
         TypeError: If `logits` or `labels` is not a Tensor.
-        TypeError: If dtype of `logits` or `labels` is
+        TypeError: If dtype of `logits` or `labels` is not float16, float32,
+          bfloat16 (Among them, the Atlas training series products do not support bfloat16).
         ValueError: If shape of `logits` is not the same as `labels`.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
@@ -758,10 +769,10 @@ class SoftMarginLoss(LossBase):

     def __init__(self, reduction='mean'):
         super(SoftMarginLoss, self).__init__()
-        self.soft_margin_loss = P.SoftMarginLoss(reduction)
+        self.reduction = reduction

     def construct(self, logits, labels):
-        return self.soft_margin_loss(logits, labels)
+        return F.soft_margin_loss(logits, labels, self.reduction)


 class SoftmaxCrossEntropyWithLogits(LossBase):
@@ -809,8 +820,8 @@ class SoftmaxCrossEntropyWithLogits(LossBase):

     Raises:
         TypeError: If `sparse` is not a bool.
-        TypeError: If `sparse` is True and dtype of `labels` is neither int32 nor int64.
-        TypeError: If `sparse` is False and dtype of `labels` is neither float16 not float32.
+        TypeError: If `sparse` is ``True`` and dtype of `labels` is neither int32 nor int64.
+        TypeError: If `sparse` is ``False`` and dtype of `labels` is neither float16 not float32.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.

     Supported Platforms:
@@ -889,8 +900,8 @@ class DiceLoss(LossBase):
     :math:`pred` represent `logits`, :math:`true` represent `labels` .

     Args:
-        smooth (float): A term added to the denominator to improve numerical stability.
-
+        smooth (float, optional): A term added to the denominator to improve numerical stability.
+            Should be greater than 0. Default: ``1e-5`` .

     Inputs:
         - **logits** (Tensor) - Input predicted value. The data type must be float16 or float32.
@@ -934,11 +945,12 @@ class DiceLoss(LossBase):
         if label.dtype == mstype.uint8:
             raise TypeError(f"For '{self.cls_name}', the dtype of 'labels' can not be uint8.")
         intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
-        unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
-                   self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+        unionset_part1 = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1)))
+        unionset_part2 = self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
+        unionset = ops.add(unionset_part1, unionset_part2)

-        single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
-        dice_loss = 1 - single_dice_coeff
+        single_dice_coeff = (2 * intersection) / ops.add(unionset, self.smooth)
+        dice_loss = ops.sub(1, single_dice_coeff)

         return dice_loss
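
The change above only splits the union term into named intermediates and routes the arithmetic through `ops`; the soft-Dice formula itself is unchanged. A standalone NumPy sketch of the same computation (illustrative values):

```python
import numpy as np

def dice_loss(logits, labels, smooth=1e-5):
    # Soft Dice: 1 - 2*|X.Y| / (|X|^2 + |Y|^2 + smooth), over flattened inputs.
    logits, labels = logits.ravel(), labels.ravel()
    intersection = np.sum(logits * labels)
    union = np.sum(logits * logits) + np.sum(labels * labels)
    return 1.0 - (2.0 * intersection) / (union + smooth)

pred = np.array([[0.9, 0.1], [0.2, 0.8]])
truth = np.array([[1.0, 0.0], [0.0, 1.0]])
print(dice_loss(pred, truth))  # ~0.0286: prediction is close to the target
```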
@@ -1054,7 +1066,7 @@ class MultiClassDiceLoss(LossBase):
             dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
             if self.weights is not None:
                 _check_weights(self.weights.shape[0], label.shape[1], self.cls_name)
-                dice_loss *= self.weights[i]
+                dice_loss = dice_loss * self.weights[i]
             total_loss += dice_loss

         return total_loss / label.shape[1]
@@ -1631,7 +1643,7 @@ class MultiMarginLoss(LossBase):
     def __init__(self, p=1, margin=1.0, reduction='mean', weight=None):
         """Initialize MultiMarginLoss."""
         super(MultiMarginLoss, self).__init__()
-        self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction)
+        self.multi_margin_loss = ops.MultiMarginLoss(p=p, margin=margin, reduction=reduction)
         self.weight = weight

     def construct(self, x, target, weight=None):
@@ -1718,22 +1730,11 @@ class BCELoss(LossBase):
     def __init__(self, weight=None, reduction='mean'):
         """Initialize BCELoss."""
         super(BCELoss, self).__init__(reduction)
-        self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
-        self.weight_one = weight is None
-        if not self.weight_one:
-            self.weight = weight
-        else:
-            self.ones = P.OnesLike()
+        self.reduction = reduction
+        self.weight = weight

     def construct(self, logits, labels):
-        _check_is_tensor('logits', logits, self.cls_name)
-        _check_is_tensor('labels', labels, self.cls_name)
-        if self.weight_one:
-            weight = self.ones(logits)
-        else:
-            weight = self.weight
-        loss = self.binary_cross_entropy(logits, labels, weight)
-        return loss
+        return F.binary_cross_entropy(logits, labels, self.weight, self.reduction)


 class CosineEmbeddingLoss(LossBase):
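
The BCELoss refactor above drops the primitive-plus-`OnesLike` weight bookkeeping in favor of the functional `F.binary_cross_entropy`, which per this diff takes the (possibly ``None``) weight directly. A plain-NumPy check of the weighted binary cross-entropy the class computes (values are illustrative):

```python
import numpy as np

def bce_loss(p, y, weight=None, reduction='mean'):
    # Weighted BCE: -w * (y*log(p) + (1-y)*log(1-p)); weight defaults to
    # ones, which is exactly the OnesLike fallback the old __init__ carried.
    w = np.ones_like(p) if weight is None else weight
    loss = -w * (y * np.log(p) + (1 - y) * np.log(1 - p))
    if reduction == 'mean':
        return loss.mean()
    return loss.sum() if reduction == 'sum' else loss

p = np.array([0.9, 0.2, 0.7])  # predicted probabilities
y = np.array([1.0, 0.0, 1.0])  # binary targets
print(bce_loss(p, y))  # ~0.2284
```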
@@ -1887,7 +1888,7 @@ class MultilabelMarginLoss(LossBase):

     def __init__(self, reduction='mean'):
         super(MultilabelMarginLoss, self).__init__()
-        self.multilabel_margin_loss = MultilabelMarginLossOp(reduction=reduction)
+        self.multilabel_margin_loss = ops.MultilabelMarginLoss(reduction=reduction)

     def construct(self, x, target):
         loss, _ = self.multilabel_margin_loss(x, target)
|
|
|
2265
2266
|
- ``'mean'``: compute and return the mean of elements in the output.
|
|
2266
2267
|
- ``'sum'``: the output elements will be summed.
|
|
2267
2268
|
|
|
2268
|
-
margin (Union[Tensor, float]): Make a margin between the positive pair and the negative pair.
|
|
2269
|
+
margin (Union[Tensor, float]): Make a margin between the positive pair and the negative pair. The length of
|
|
2270
|
+
shape of `margin` must be 0.
|
|
2269
2271
|
Default: ``1.0`` .
|
|
2270
2272
|
|
|
2271
2273
|
Inputs:
|
|
```diff
@@ -2275,7 +2277,8 @@ class TripletMarginLoss(LossBase):
           shape as `x`. :math:`p` in the above formula.
         - **negative** (Tensor) - A sample belonging to the different class from `x`, with the same type and shape
           as `x`. :math:`n` in the above formula.
-        - **margin** (Union[Tensor, float]) - Make a margin between the positive pair and the negative pair.
+        - **margin** (Union[Tensor, float]) - Make a margin between the positive pair and the negative pair. The length
+          of shape of `margin` must be 0.
           Default: ``1.0`` .
 
     Outputs:
```
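"The length of shape of `margin` must be 0" means that when `margin` is passed as a Tensor rather than a float, it has to be a scalar (0-D) tensor. A quick sketch of the distinction:

```python
import mindspore as ms
from mindspore import Tensor

margin_ok = Tensor(1.0, ms.float32)      # shape () -- 0-D, accepted
margin_bad = Tensor([1.0], ms.float32)   # shape (1,) -- 1-D, rejected by the doc's rule
print(margin_ok.shape, margin_bad.shape)
```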
```diff
@@ -2576,7 +2579,7 @@ class KLDivLoss(LossBase):
     the updating formulas of KLDivLoss algorithm are as follows,
 
     .. math::
-        L(x, target) = target \cdot (\log target - x)
+        L(x, target) = target \cdot (\log target - \log x)
 
     Then,
 
```
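A small numeric check of the corrected elementwise formula, treating `x` and `target` as probability vectors (a NumPy stand-in with hypothetical values; the operator itself also applies the chosen reduction). Note the two doc versions coincide only if the old `x` was read as log-probabilities, since `log target - log x` equals `log target - x` when `x` is already a log value:

```python
import numpy as np

x = np.array([0.2, 0.3, 0.5])        # hypothetical predicted distribution
target = np.array([0.1, 0.4, 0.5])   # hypothetical target distribution
L = target * (np.log(target) - np.log(x))   # elementwise L(x, target) per the fixed doc
print(L, L.mean())                           # per-element terms and 'mean' reduction
```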
```diff
@@ -2870,7 +2873,7 @@ class HingeEmbeddingLoss(LossBase):
     where :math:`L = \{l_1,\dots,l_N\}^\top`.
 
     Args:
-        margin (float, int): Threshold defined by Hinge Embedding Loss :math:`margin`.
+        margin (float, int, optional): Threshold defined by Hinge Embedding Loss :math:`margin`.
             Represented as :math:`\Delta` in the formula. Default: ``1.0`` .
         reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
             ``'sum'`` . Default: ``'mean'`` .
```
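This hunk is documentation-only (marking `margin` as optional). For reference, a usage sketch with the documented defaults (tensor values are hypothetical; `labels` take values in {-1, 1}):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

loss_fn = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
logits = Tensor(np.array([0.3, 0.7, 0.5]), ms.float32)
labels = Tensor(np.array([1, -1, 1]), ms.float32)   # {-1, 1} labels
print(loss_fn(logits, labels))
```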
mindspore/nn/optim/ada_grad.py
CHANGED
```diff
@@ -78,6 +78,7 @@ class Adagrad(Optimizer):
     :math:`state\_sum` stands for the accumulated squared sum of the gradients :math:`accum`.
     :math:`g` stands for `grads`, :math:`\lambda` stands for `weight_decay`.
     :math:`\gamma` stands for `learning_rate`, :math:`w` stands for `params`.
+    :math:`t` represents current `step`.
 
     Note:
         If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
```
```diff
@@ -112,8 +113,8 @@ class Adagrad(Optimizer):
             If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
             one group of `params`.
 
-        accum (float): The starting value for :math:`h`, must be zero or positive values. Default: ``0.1`` .
-        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .
+        accum (float, optional): The starting value for :math:`h`, must be zero or positive values. Default: ``0.1`` .
+        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``0.001`` .
 
             - float: The fixed learning rate value. Must be equal to or greater than 0.
 
```
```diff
@@ -129,13 +130,14 @@ class Adagrad(Optimizer):
             <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
             with step as the input to get the learning rate of current step.
 
-        update_slots (bool): Whether the :math:`h` will be updated. Default: ``True`` .
-        loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general,
+        update_slots (bool, optional): Whether the :math:`h` will be updated. Default: ``True`` .
+        loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+            use the default value.
             Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
             `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
             `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
             Default: ``1.0`` .
-        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+        weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
             - float: The fixed weight decay value. Must be equal to or greater than 0.
 
```
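All three `ada_grad.py` hunks are documentation-only (adding the `:math:`t`` definition and marking arguments as optional). For orientation, a constructor sketch with the documented defaults; the `nn.Dense` network is a hypothetical placeholder:

```python
from mindspore import nn

net = nn.Dense(3, 2)   # hypothetical placeholder network
optim = nn.Adagrad(net.trainable_params(), accum=0.1, learning_rate=0.001,
                   update_slots=True, loss_scale=1.0, weight_decay=0.0)
```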
mindspore/nn/optim/adadelta.py
CHANGED
```diff
@@ -68,8 +68,8 @@ class Adadelta(Optimizer):
 
     Args:
         params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
-            `params` is a list of `dict`, the string "params"
-            "order_params" are the keys can be parsed.
+            `params` is a list of `dict`, the string `"params"`, `"lr"`, `"weight_decay"`, `"grad_centralization"` and
+            `"order_params"` are the keys can be parsed.
 
             - params: Required. Parameters in current group. The value must be a list of `Parameter`.
 
```
```diff
@@ -93,7 +93,7 @@ class Adadelta(Optimizer):
             If `order_params` in the keys, other keys will be ignored and the element of 'order_params' must be in
             one group of `params`.
 
-        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``1.0`` .
+        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule], optional): Default: ``1.0`` .
 
             - float: The fixed learning rate value. Must be equal to or greater than 0.
 
```
```diff
@@ -109,14 +109,16 @@ class Adadelta(Optimizer):
             <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
             with step as the input to get the learning rate of current step.
 
-        rho (float): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
-        epsilon (float): A small value added for numerical stability, must be non-negative.
-
+        rho (float, optional): Decay rate, must be in range [0.0, 1.0]. Default: ``0.9`` .
+        epsilon (float, optional): A small value added for numerical stability, must be non-negative.
+            Default: ``1e-6`` .
+        loss_scale (float, optional): Value for the loss scale. It must be greater than 0.0. In general,
+            use the default value.
             Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
             `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
             `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
             Default: ``1.0`` .
-        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .
+        weight_decay (Union[float, int, Cell], optional): Weight decay (L2 penalty). Default: ``0.0`` .
 
             - float: The fixed weight decay value. Must be equal to or greater than 0.
 
```
```diff
@@ -134,9 +136,9 @@ class Adadelta(Optimizer):
 
     Raises:
         TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
-        TypeError: If element of `
+        TypeError: If element of `params` is neither Parameter nor dict.
         TypeError: If `rho`, `epsilon` or `loss_scale` is not a float.
-        TypeError: If `weight_decay` is
+        TypeError: If `weight_decay` is not float, int or cell.
         ValueError: if `rho` is not in range [0.0, 1.0].
         ValueError: If `loss_scale` is less than or equal to 0.
         ValueError: If `learning_rate`, `epsilon` or `weight_decay` is less than 0.
```
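As with Adagrad, the `adadelta.py` hunks only touch the docstring (filling in the truncated key list, the `epsilon` default, and the Raises entries). A constructor sketch with the documented defaults; the `nn.Dense` network is again a hypothetical placeholder:

```python
from mindspore import nn

net = nn.Dense(3, 2)   # hypothetical placeholder network
optim = nn.Adadelta(net.trainable_params(), learning_rate=1.0, rho=0.9,
                    epsilon=1e-6, loss_scale=1.0, weight_decay=0.0)
```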