mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/mint/nn/layer/pooling.py
CHANGED

@@ -104,7 +104,242 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
         return mint.nn.functional.adaptive_avg_pool2d(input, self.output_size)
 
 
+class AdaptiveAvgPool3d(Cell):
+    r"""
+    This operator applies a 3D adaptive average pooling to an input signal composed of multiple input planes.
+    That is, for any input size, the size of the specified output is :math:`(D, H, W)`.
+    The number of output features is equal to the number of input planes.
+
+    Suppose the last 3 dimensions of the input have size :math:`(inD, inH, inW)`; then the last 3 dimensions
+    of the output have size :math:`(outD, outH, outW)`.
+
+    .. math::
+        \begin{array}{ll} \\
+            \forall \quad od \in [0,outD-1], oh \in [0,outH-1], ow \in [0,outW-1]\\
+            output[od,oh,ow] = \\
+            \qquad mean(input[istartD:iendD+1,istartH:iendH+1,istartW:iendW+1])\\
+            where,\\
+            \qquad istartD= \left\lceil \frac{od * inD}{outD} \right\rceil \\
+            \qquad iendD=\left\lfloor \frac{(od+1)* inD}{outD} \right\rfloor \\
+            \qquad istartH=\left\lceil \frac{oh * inH}{outH} \right\rceil \\
+            \qquad iendH=\left\lfloor \frac{(oh+1) * inH}{outH} \right\rfloor \\
+            \qquad istartW=\left\lceil \frac{ow * inW}{outW} \right\rceil \\
+            \qquad iendW=\left\lfloor \frac{(ow+1) * inW}{outW} \right\rfloor
+        \end{array}
+
+    .. warning::
+        For Ascend, it is only supported on Atlas A2 Training Series Products.
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        output_size (Union[int, tuple]): The target output size. `output_size` can be a tuple :math:`(D, H, W)`,
+            or an int D for :math:`(D, D, D)`. :math:`D`, :math:`H` and :math:`W` can be int or None,
+            where None means the corresponding output size is the same as that of the input.
+
+    Inputs:
+        - **input** (Tensor) - The input of AdaptiveAvgPool3d, which is a 5D or 4D Tensor.
+
+    Outputs:
+        Tensor, with the same type as the `input`.
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        ValueError: If the dimension of `input` is not 4D or 5D.
+        ValueError: If `output_size` value is not positive.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import mint
+        >>> import numpy as np
+        >>> # case 1: output_size=(3, 3, 4)
+        >>> output_size=(3, 3, 4)
+        >>> input_x_val = np.random.randn(4, 3, 5, 6, 7)
+        >>> input_x = ms.Tensor(input_x_val, ms.float32)
+        >>> net = mint.nn.AdaptiveAvgPool3d(output_size)
+        >>> output = net(input_x)
+        >>> print(output.shape)
+        (4, 3, 3, 3, 4)
+        >>> # case 2: output_size=5
+        >>> output_size=5
+        >>> input_x_val = np.random.randn(2, 3, 8, 6, 12)
+        >>> input_x = ms.Tensor(input_x_val, ms.float32)
+        >>> net = mint.nn.AdaptiveAvgPool3d(output_size)
+        >>> output = net(input_x)
+        >>> print(output.shape)
+        (2, 3, 5, 5, 5)
+        >>> # case 3: output_size=(None, 4, 5)
+        >>> output_size=(None, 4, 5)
+        >>> input_x_val = np.random.randn(4, 1, 9, 10, 8)
+        >>> input_x = ms.Tensor(input_x_val, ms.float32)
+        >>> net = mint.nn.AdaptiveAvgPool3d(output_size)
+        >>> output = net(input_x)
+        >>> print(output.shape)
+        (4, 1, 9, 4, 5)
+    """
+
+    def __init__(self, output_size):
+        """Initialize AdaptiveAvgPool3d."""
+        super(AdaptiveAvgPool3d, self).__init__()
+        self.output_size = output_size
+
+    def construct(self, input):
+        return mint.nn.functional.adaptive_avg_pool3d(input, self.output_size)
+
+
+class MaxUnpool2d(Cell):
+    r"""
+    Computes the inverse of `MaxPool2d`.
+
+    `MaxUnpool2d` keeps the maximal values and sets all non-maximal positions to zero.
+    Typically the input is of shape :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`,
+    and the output is of shape :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`.
+    The operation is as follows.
+
+    .. math::
+        \begin{array}{ll} \\
+            H_{out} = (H_{in} - 1) \times stride[0] - 2 \times padding[0] + kernel\_size[0] \\
+            W_{out} = (W_{in} - 1) \times stride[1] - 2 \times padding[1] + kernel\_size[1] \\
+        \end{array}
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
+            an int number that represents both the height and width of the kernel,
+            or a tuple of two int numbers that represent height and width respectively.
+        stride (Union[int, tuple[int]], optional): The distance of kernel moving,
+            an int number that applies to both the height and width of movement,
+            or a tuple of two int numbers that represent height and width of movement respectively.
+            Default: ``None`` , which indicates the moving step is `kernel_size` .
+        padding (Union[int, tuple[int]], optional): The pad value to be filled. Default: ``0`` .
+            If `padding` is an integer, the paddings of height and width are the same, equal to padding.
+            If `padding` is a tuple of two integers, the padding of height and width equal to padding[0]
+            and padding[1] correspondingly.
+
+    Inputs:
+        - **input** (Tensor) - The input Tensor to invert.
+          Tensor of shape :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
+        - **indices** (Tensor) - The indices of the max values.
+          Its shape must be the same as that of `input`.
+          Values of indices must belong to :math:`[0, H_{in} \times W_{in} - 1]`.
+          Data type must be int32 or int64.
+        - **output_size** (tuple[int], optional) - The target output size. Default: ``None`` .
+          If output_size == (), then the output shape is computed from `kernel_size`, `stride` and `padding`.
+          If output_size != (), then output_size must be :math:`(N, C, H, W)` , :math:`(C, H, W)` or :math:`(H, W)`
+          and output_size must belong to
+          :math:`[(N, C, H_{out} - stride[0], W_{out} - stride[1]), (N, C, H_{out} + stride[0], W_{out} + stride[1])]`.
+
+    Outputs:
+        Tensor, with shape :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`,
+        with the same data type as `input`.
+
+    Raises:
+        TypeError: If data type of `input` or `indices` is not supported.
+        TypeError: If `kernel_size`, `stride` or `padding` is neither an int nor a tuple.
+        ValueError: If numbers in `stride`, `padding` or `kernel_size` are not positive.
+        ValueError: If the shapes of `input` and `indices` are not equal.
+        ValueError: If the rank of `input` is neither 3 nor 4.
+        ValueError: If `output_size` is not a tuple.
+        ValueError: If `output_size` is not close to the output size computed from `kernel_size`, `stride`, `padding`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, mint
+        >>> input = Tensor(np.array([[[[0, 1], [8, 9]]]]).astype(np.float32))
+        >>> indices = Tensor(np.array([[[[0, 1], [2, 3]]]]).astype(np.int64))
+        >>> net = mint.nn.MaxUnpool2d(1, stride=1, padding=0)
+        >>> output = net(input, indices)
+        >>> print(output.asnumpy())
+        [[[[0. 1.]
+           [8. 9.]]]]
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0) -> None:
+        super(MaxUnpool2d, self).__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+
+    def construct(self, input, indices, output_size=None):
+        return mint.nn.functional.max_unpool2d(input, indices,
+                                               self.kernel_size, self.stride,
+                                               self.padding, output_size)
+
+
+class _AdaptiveMaxPoolNd(Cell):
+    """Common base of AdaptiveMaxPool1d"""
+
+    def __init__(self, output_size, return_indices=False) -> None:
+        super(_AdaptiveMaxPoolNd, self).__init__()
+        self.output_size = output_size
+        self.return_indices = return_indices
+
+    def extend_repr(self):
+        return 'output_size={}, return_indices={}'.format(self.output_size, self.return_indices)
+
+
+class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
+    r"""
+    Applies a 1D adaptive max pooling over an input signal composed of several input planes.
+
+    The output is of size :math:`L_{out}` , for any input size.
+    The number of output features is equal to the number of input planes.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        output_size (Union[int, tuple]): the target output size :math:`L_{out}` .
+        return_indices (bool, optional): Whether to return the index of the maximum value. Default: ``False`` .
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C, L_{in})` or :math:`(C, L_{in})` .
+
+    Outputs:
+        Union(Tensor, tuple(Tensor, Tensor)).
+
+        - If `return_indices` is False, output is a Tensor, with shape :math:`(N, C, L_{out})`. It has the same data
+          type as `input`.
+        - If `return_indices` is True, output is a Tuple of 2 Tensors, representing the result and where the max
+          values are generated.
+
+    Raises:
+        TypeError: If `input` is not a tensor.
+        TypeError: If dtype of `input` is not float16, float32 or float64.
+        TypeError: If `output_size` is not int or tuple.
+        TypeError: If `return_indices` is not a bool.
+        ValueError: If `output_size` is a tuple and the length of `output_size` is not 1.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[[2, 1, 2], [2, 3, 5]]]), mindspore.float16)
+        >>> net = mint.nn.AdaptiveMaxPool1d(3)
+        >>> output = net(input)
+        >>> print(output)
+        [[[2. 1. 2.]
+          [2. 3. 5.]]]
+    """
+
+    def construct(self, input):
+        return mint.nn.functional.adaptive_max_pool1d(input, self.output_size, self.return_indices)
+
+
 __all__ = [
+    'AdaptiveAvgPool3d',
     'AdaptiveAvgPool2d',
     'AdaptiveAvgPool1d',
+    'AdaptiveMaxPool1d',
+    'MaxUnpool2d',
 ]
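The shape formula in the MaxUnpool2d docstring above is plain integer arithmetic, so it is easy to sanity-check on the host. A minimal sketch (illustrative only, not part of the diff; `unpool_out_len` is a hypothetical helper name):

def unpool_out_len(in_len, kernel_size, stride, padding):
    # H_out = (H_in - 1) * stride - 2 * padding + kernel_size, per the docstring.
    return (in_len - 1) * stride - 2 * padding + kernel_size

# kernel_size=1, stride=1, padding=0, as in the docstring example: shape is preserved.
assert unpool_out_len(2, 1, 1, 0) == 2
# kernel_size=2, stride=2, the usual inverse of a 2x2 MaxPool2d: each side doubles.
assert unpool_out_len(2, 2, 2, 0) == 4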
mindspore/mint/optim/__init__.py
CHANGED

@@ -20,5 +20,7 @@ The optimizer is used to calculate and update the gradients.
 """
 from __future__ import absolute_import
 from mindspore.mint.optim.adamw import AdamW
+from mindspore.mint.optim.adam import Adam
+from mindspore.mint.optim.sgd import SGD
 
-__all__ = ['AdamW']
+__all__ = ['AdamW', 'Adam', 'SGD']
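With these exports in place, the new optimizers sit next to the existing AdamW under `mindspore.mint.optim`. A short usage sketch (the Dense layer is just a placeholder network; the SGD keyword list is assumed from the new sgd.py, whose body is not shown in this diff):

from mindspore import nn, mint

net = nn.Dense(4, 2)  # any Cell with trainable parameters works here
opt_adamw = mint.optim.AdamW(net.trainable_params(), lr=1e-3)
opt_adam = mint.optim.Adam(net.trainable_params(), lr=1e-3, betas=(0.9, 0.999))
opt_sgd = mint.optim.SGD(net.trainable_params(), lr=1e-2)  # signature assumed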
mindspore/mint/optim/adam.py
ADDED

@@ -0,0 +1,223 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Adam"""
+from __future__ import absolute_import
+
+from mindspore.ops import functional as F, composite as C, operations as P
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.experimental.optim.optimizer import Optimizer
+from mindspore import _checkparam as validator
+from mindspore import mint
+
+_optim_adamw_opt = C.MultitypeFuncGraph("optim_adamw_opt")
+hyper_map = C.HyperMap()
+assign_add = P.AssignAdd()
+
+
+@_optim_adamw_opt.register("Float", "Float", "Float", "Tensor", "Tensor", "Tensor", "Tensor",
+                           "Tensor", "Tensor", "Tensor")
+def _run_optim_adamw_amsgrad_opt(beta1, beta2, eps, neg_step_size, sqrt_bias_correction2, parameters, grads, exp_avg,
+                                 exp_avg_sq, max_exp_avg_sq):
+    """Apply adam optimizer to the weight parameter when amsgrad is True."""
+    success = True
+    exp_avg_tmp = mint.add(mint.mul(exp_avg, beta1), grads, alpha=1 - beta1)
+    exp_avg_sq_tmp = mint.mul(exp_avg_sq, beta2) + mint.mul(mint.mul(grads, grads), 1 - beta2)
+
+    max_exp_avg_sq = mint.maximum(max_exp_avg_sq, exp_avg_sq_tmp)
+    denom = F.cast(mint.div(mint.sqrt(max_exp_avg_sq), sqrt_bias_correction2), max_exp_avg_sq.dtype)
+    denom = mint.add(denom, eps)
+
+    delta_param = mint.mul(F.cast(neg_step_size, max_exp_avg_sq.dtype), mint.div(exp_avg_tmp, denom))
+    F.assign(exp_avg, exp_avg_tmp)
+    F.assign(exp_avg_sq, exp_avg_sq_tmp)
+    assign_add(parameters, delta_param)
+    return success
+
+
+@_optim_adamw_opt.register("Float", "Float", "Float", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
+def _run_optim_adamw_opt(beta1, beta2, eps, neg_step_size, sqrt_bias_correction2, parameters, grads, exp_avg,
+                         exp_avg_sq):
+    """Apply adam optimizer to the weight parameter when amsgrad is False."""
+    success = True
+    exp_avg_tmp = mint.add(mint.mul(exp_avg, beta1), grads, alpha=1 - beta1)
+    exp_avg_sq_tmp = mint.mul(exp_avg_sq, beta2) + mint.mul(mint.mul(grads, grads), 1 - beta2)
+
+    denom = F.cast(mint.div(mint.sqrt(exp_avg_sq_tmp), sqrt_bias_correction2), exp_avg_sq_tmp.dtype)
+    denom = mint.add(denom, eps)
+
+    delta_param = mint.mul(F.cast(neg_step_size, exp_avg_sq_tmp.dtype), mint.div(exp_avg_tmp, denom))
+    F.assign(exp_avg, exp_avg_tmp)
+    F.assign(exp_avg_sq, exp_avg_sq_tmp)
+    assign_add(parameters, delta_param)
+    return success
+
+
+def _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, prim_name):
+    """Check the type of inputs."""
+    validator.check_value_type('betas', betas, [tuple], prim_name)
+    validator.check("betas size", len(betas), "", [2], validator.IN, prim_name)
+    validator.check_value_type("betas[0]", betas[0], [float], prim_name)
+    validator.check_value_type("betas[1]", betas[1], [float], prim_name)
+    validator.check_value_type("eps", eps, [float], prim_name)
+    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
+    validator.check_value_type("lr", lr, [float], prim_name)
+    validator.check_value_type("amsgrad", amsgrad, [bool], prim_name)
+    validator.check_value_type("maximize", maximize, [bool], prim_name)
+
+
+class Adam(Optimizer):
+    r"""
+    Implements Adaptive Moment Estimation (Adam) algorithm.
+
+    The updating formulas are as follows:
+
+    .. math::
+        \begin{aligned}
+            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
+                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
+            &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
+                \:\textit{maximize} \\
+            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
+                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
+            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
+            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
+            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{5mm}\textbf{else} \\
+            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
+            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
+            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
+            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
+            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
+            &\hspace{5mm}\textbf{if} \: amsgrad \\
+            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
+                \widehat{v_t}) \\
+            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
+                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
+            &\hspace{5mm}\textbf{else} \\
+            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
+                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
+            &\bf{return} \: \theta_t \\[-1.ex]
+        \end{aligned}
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
+            parameter groups.
+        lr (Union[int, float, Tensor], optional): learning rate. Default: ``1e-3``.
+        betas (Tuple[float, float], optional): The exponential decay rate for the moment estimations.
+            Should be in range (0.0, 1.0). Default: ``(0.9, 0.999)``.
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. Should be greater than 0. Default: ``1e-8``.
+        weight_decay (float, optional): weight decay (L2 penalty). Default: ``0.``.
+        amsgrad (bool, optional): whether to use the AMSGrad algorithm. Default: ``False``.
+
+    Keyword Args:
+        maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
+            Default: ``False``.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`.
+
+    Raises:
+        ValueError: If the `lr` is not int, float or Tensor.
+        ValueError: If the `lr` is less than 0.
+        ValueError: If the `eps` is less than 0.0.
+        ValueError: If the `betas` is not in the range of [0, 1).
+        ValueError: If the `weight_decay` is less than 0.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import nn
+        >>> from mindspore import mint
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+        >>> optimizer = mint.optim.Adam(net.trainable_params(), lr=0.1)
+        >>> def forward_fn(data, label):
+        ...     logits = net(data)
+        ...     loss = loss_fn(logits, label)
+        ...     return loss, logits
+        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
+        >>> def train_step(data, label):
+        ...     (loss, _), grads = grad_fn(data, label)
+        ...     optimizer(grads)
+        ...     return loss
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0.0, amsgrad=False, *, maximize=False):
+        _check_param_value(betas, eps, weight_decay, lr, amsgrad, maximize, self.cls_name)
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if eps < 0.0:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad,
+                        maximize=maximize)
+        self.max_v_group = True
+        super(Adam, self).__init__(params, defaults)
+
+        self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
+        self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
+        self.state_step = Parameter(Tensor([0], mstype.float32), "state_step")
+        self.increase_tensor = Tensor(1, mstype.float32)
+        self.assignadd = P.AssignAdd()
+        self.pow = P.Pow()
+
+
+    def construct(self, gradients):
+        self.assignadd(self.state_step, self.increase_tensor)
+        for group_id, group in enumerate(self.param_groups):
+            beta1, beta2 = group['betas']
+            maximize = group.get("maximize")
+            start_id = self.group_start_id[group_id]
+            end_id = self.group_start_id[group_id + 1]
+            lr = group.get("lr")
+            grads = tuple([grad if not maximize else mint.neg(grad) for grad in gradients[start_id: end_id]])
+
+            bias_correction1 = 1 - beta1 ** self.state_step
+            bias_correction2 = 1 - beta2 ** self.state_step
+            neg_step_size = -mint.div(lr, bias_correction1)
+            sqrt_bias_correction2 = mint.sqrt(bias_correction2)
+            grads = self._decay_weight(group.get("weight_decay"), self.parameters[start_id: end_id], grads)
+
+            if group.get("amsgrad"):
+                self.hyper_map(F.partial(_optim_adamw_opt, beta1, beta2, group.get("eps"), neg_step_size,
+                                         sqrt_bias_correction2),
+                               self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                               self.exp_avg_sq[start_id: end_id], group.get("max_exp_avg_sq"))
+            else:
+                self.hyper_map(F.partial(_optim_adamw_opt, beta1, beta2, group.get("eps"), neg_step_size,
+                                         sqrt_bias_correction2),
+                               self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                               self.exp_avg_sq[start_id: end_id])
+        return True
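Note how construct() folds both bias corrections into two scalars instead of materializing m_hat and v_hat per parameter: neg_step_size = -lr / (1 - beta1^t) and sqrt_bias_correction2 = sqrt(1 - beta2^t), with the per-parameter kernels dividing by sqrt(v_t) / sqrt_bias_correction2 + eps. A small NumPy check (illustrative only, not part of the diff) confirms this fused form equals the textbook update from the docstring:

import numpy as np

rng = np.random.default_rng(0)
lr, beta1, beta2, eps, t = 1e-3, 0.9, 0.999, 1e-8, 5
m = rng.standard_normal(4)       # stands in for exp_avg after t steps
v = rng.standard_normal(4) ** 2  # stands in for exp_avg_sq after t steps

# Docstring form: theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)
textbook = -lr * m_hat / (np.sqrt(v_hat) + eps)

# Fused form used by the kernels above
neg_step_size = -lr / (1 - beta1 ** t)
sqrt_bias_correction2 = np.sqrt(1 - beta2 ** t)
fused = neg_step_size * m / (np.sqrt(v) / sqrt_bias_correction2 + eps)

np.testing.assert_allclose(fused, textbook)  # identical up to float rounding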
mindspore/mint/optim/adamw.py
CHANGED

@@ -15,7 +15,7 @@
 """adamw"""
 from __future__ import absolute_import
 
-from mindspore.ops import functional as F, composite as C, operations as P
+from mindspore.ops import functional as F, composite as C
 from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype

@@ -62,7 +62,9 @@ class AdamW(Optimizer):
     Implements Adam Weight Decay algorithm.
 
     .. math::
-        \begin{aligned}
+        \begin{array}{l}
+            &\newline
+            &\hline \\
             &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
                 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                 \: \epsilon \text{ (epsilon)} \\

@@ -70,26 +72,32 @@ class AdamW(Optimizer):
                 \: \textit{maximize} \\
             &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                 \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
+            &\newline
+            &\hline \\
             &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
-            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
-            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
-            &\hspace{5mm}\textbf{else} \\
-            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
-            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
-            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
-            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
-            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
-            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
-            &\hspace{5mm}\textbf{if} \: amsgrad \\
-            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
+            &\hspace{6mm}\textbf{if} \: \textit{maximize}: \\
+            &\hspace{11mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{6mm}\textbf{else} \\
+            &\hspace{11mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+            &\hspace{6mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
+            &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
+            &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
+            &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
+            &\hspace{6mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
+            &\hspace{6mm}\textbf{if} \: amsgrad \\
+            &\hspace{11mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                 \widehat{v_t}) \\
-            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+            &\hspace{11mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                 \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
-            &\hspace{5mm}\textbf{else} \\
-            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
+            &\hspace{6mm}\textbf{else} \\
+            &\hspace{11mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                 \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
+            &\newline
+            &\hline \\[-1.ex]
             &\bf{return} \: \theta_t \\[-1.ex]
-        \end{aligned}
+            &\newline
+            &\hline \\[-1.ex]
+        \end{array}
 
     .. warning::
         - This is an experimental optimizer API that is subject to change.

@@ -169,11 +177,10 @@ class AdamW(Optimizer):
         self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
         self.state_step = Parameter(Tensor([-1], mstype.float32), "state_step")
         self.increase_tensor = Tensor(1, mstype.float32)
-        self.assignadd = P.AssignAdd()
         self.adamw_opt = gen.AdamW()
 
     def construct(self, gradients):
-        self.assignadd(self.state_step, self.increase_tensor)
+        self.state_step.add_(self.increase_tensor)
         for group_id, group in enumerate(self.param_groups):
             beta1, beta2 = group['betas']
             maximize = group.get("maximize")