mindspore-2.4.10-cp39-cp39-win_amd64.whl → mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/experimental/es/embedding_service.py

@@ -125,18 +125,20 @@ class EmbeddingServiceOut:
 
 class EmbeddingService:
     r"""
-
+    ES(EmbeddingService) feature can support model training and inference
     for PS embedding and data_parallel embedding, and provide unified embedding management, storage,
     and computing capabilities for training and inference.
     PS embedding refer to tables that vocab_size more than 100,000, and recommended to store them on the
     Parameter Server (PS). Data_parallel embedding refer to tables that vocab_size less than 100,000, and recommended
     to store them on device.
 
+    Currently, ES feature can only create one instance of EmbeddingService object.
+
     .. warning::
         This is an experimental EmbeddingService API that is subject to change.
 
     .. note::
-        This API needs to call
+        This API needs to call :func:`mindspore.communication.init` before,
         and it can take effect after the dynamic networking is completed.
 
     Raises:
@@ -241,24 +243,26 @@ class EmbeddingService:
             name (str): The embedding table name.
             init_vocabulary_size (int): The size of embedding table.
             embedding_dim (int): The embedding dim of data in embedding table.
-            max_feature_count (int): The count of keys when look up for PS.
-            initializer (Initializer): The initialization strategy for the PS embedding,
-
+            max_feature_count (int, optional): The count of keys when look up for PS. Default: ``None``.
+            initializer (Initializer, optional): The initialization strategy for the PS embedding,
+                default is ``Uniform(scale=0.01)``.
+            embedding_type (str, optional): The embedding type, configurable parameters ["PS", "data_parallel"],
                 ``"PS"`` means initializing PS embedding, ``"data_parallel"`` means initializing data_parallel
                 embedding, and default is ``"PS"``.
-            ev_option (EmbeddingVariableOption): Properties of the PS embedding,
+            ev_option (EmbeddingVariableOption, optional): Properties of the PS embedding,
                 is a EmbeddingVariableOption obj which returned by embedding_variable_option function.
                 Default is ``None``.
-            multihot_lens (int): The param only use when allow_merge is enabled, and not support now.
+            multihot_lens (int, optional): The param only use when `allow_merge` is enabled, and not support now.
                 Default is ``None``.
-            optimizer (str): The type of optimizer in the train mode for PS embedding,
+            optimizer (str, optional): The type of optimizer in the train mode for PS embedding,
                 cannot be shared among each PS embedding, and currently only ``"Adam"``, ``"Ftrl"``, ``"SGD"`` and
                 ``"RMSProp"`` are supported, and default is ``None``.
-            allow_merge (bool): Whether to enable merge data_parallel embeddings, currently only be False,
+            allow_merge (bool, optional): Whether to enable merge data_parallel embeddings, currently only be False,
                 and default is ``False``.
-            optimizer_param (float): The "initialize accumulator value" param
+            optimizer_param (float, optional): The "initialize accumulator value" param
+                of optimizer which configured by user,
                 representing the init value of moment accumulator, and default is ``None``.
-            mode (str): Run mode, configurable parameters ["train", "predict", "export"],
+            mode (str, optional): Run mode, configurable parameters ["train", "predict", "export"],
                 ``"train"`` means train mode, ``"predict"`` means predict mode, ``"export"`` mean export mode,
                 and default is ``"train"``.
 
@@ -345,8 +349,9 @@ class EmbeddingService:
 
         Args:
             padding_key (int): The value for padding key, must be a genuine and legal hash key.
-            mask (bool): Whether to update padding key. If set to false, it will not be updated.
-
+            mask (bool, optional): Whether to update padding key. If set to false, it will not be updated.
+                Default is ``True``.
+            mask_zero (bool, optional): Whether to update padding key when key is 0. Default is ``False``.
 
         Returns:
             PaddingParamsOption object.
@@ -368,7 +373,7 @@ class EmbeddingService:
 
         Args:
             completion_key (int): The value for completion key.
-            mask (bool): Whether to update completion key. If set to false, it will not be updated,
+            mask (bool, optional): Whether to update completion key. If set to false, it will not be updated,
                 and default is ``True``.
 
         Returns:
@@ -396,10 +401,11 @@ class EmbeddingService:
 
         Args:
             filter_freq (int): The frequency threshold value for feature admission.
-            default_key (int): The key that number of occurrences does not reach the threshold,
-                return value of
-
-
+            default_key (int, optional): The key that number of occurrences does not reach the threshold,
+                return value of `default_key` as the corresponding value when look up embedding,
+                and default is ``None``.
+            default_value (Union[int, float], optional): The key that number of occurrences does not
+                reach the threshold, return default value which length value is embedding dim, and default is ``None``.
 
         Returns:
             CounterFilter object.
@@ -460,16 +466,17 @@
         Set variable option for PS embedding.
 
         Args:
-            filter_option (CounterFilter): The option of counter filter. Default is ``None``.
-            padding_option (PaddingParamsOption): The option of padding key. Default is ``None``.
-            evict_option (EvictOption): The option evict. Default is ``None``.
-            completion_option (CompletionKeyOption): The option of completion key. Default is ``None``.
-            storage_option (None): Reserved option, currently not supported. Default is ``None``.
-            feature_freezing_option (None): Reserved option, currently not supported. Default is ``None``.
-            communication_option (None): Reserved option, currently not supported. Default is ``None``.
+            filter_option (CounterFilter, optional): The option of counter filter. Default is ``None``.
+            padding_option (PaddingParamsOption, optional): The option of padding key. Default is ``None``.
+            evict_option (EvictOption, optional): The option evict. Default is ``None``.
+            completion_option (CompletionKeyOption, optional): The option of completion key. Default is ``None``.
+            storage_option (None, optional): Reserved option, currently not supported. Default is ``None``.
+            feature_freezing_option (None, optional): Reserved option, currently not supported. Default is ``None``.
+            communication_option (None, optional): Reserved option, currently not supported. Default is ``None``.
 
         Returns:
-            EmbeddingVariableOption object, used as the ev_option parameter for
+            EmbeddingVariableOption object, used as the ev_option parameter for
+            :func:`mindspore.experimental.es.EmbeddingService.embedding_init` .
 
         Raises:
             TypeError: If value of "filter_option" is not None and the type of "filter_option" is not CounterFilter.
@@ -501,7 +508,8 @@ class EmbeddingService:
 
         .. note::
             This function can only be executed by rank 0.
-            Need to call embedding_variable_option
+            Need to call :func:`mindspore.experimental.es.EmbeddingService.embedding_variable_option`
+            to set evict_option for each PS embedding before export.
 
         Args:
             file_path (str): The path to export embedding ckpt, and the last character cannot be ``"/"``.
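Read together, these docstring updates describe a small option-building workflow: construct per-feature options, bundle them with `embedding_variable_option`, and pass the result as `ev_option` to `embedding_init`. A rough sketch of that flow follows; only `embedding_variable_option` and `embedding_init` are confirmed by the `:func:` references above, so the builder-method name `counter_filter` and the exact call signatures are assumptions inferred from the documented return types, not an authoritative usage.

```python
from mindspore.communication import init
from mindspore.experimental.es.embedding_service import EmbeddingService

init()  # per the note above, dynamic networking must be up before using ES

es = EmbeddingService()  # only one instance may be created

# Keys seen fewer than filter_freq times resolve to default_key on lookup.
flt = es.counter_filter(filter_freq=5, default_key=0)  # method name assumed

# Bundle table options; the returned EmbeddingVariableOption feeds ev_option.
ev_opt = es.embedding_variable_option(filter_option=flt)

table = es.embedding_init(
    "user_table", init_vocabulary_size=1_000_000, embedding_dim=16,
    max_feature_count=4096, ev_option=ev_opt, optimizer="Adam", mode="train")
```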
mindspore/experimental/llm_boost/__init__.py

@@ -16,6 +16,7 @@
 from __future__ import absolute_import
 
 from mindspore.experimental.llm_boost.atb import LlamaBoost, QwenBoost
+from mindspore.experimental.llm_boost.ascend_native import *
 from mindspore.experimental.llm_boost.register import LlmBoostRegister
 
 __all__ = ["LlmBoostRegister"]
mindspore/experimental/llm_boost/ascend_native/__init__.py (new file)

@@ -0,0 +1,22 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Provide llm boost for inference, such as LlamaBoost.
+"""
+from __future__ import absolute_import
+
+from mindspore.experimental.llm_boost.ascend_native.llama_boost_ascend_native import LlamaBoostAscendNative
+
+__all__ = ['LlamaBoostAscendNative']
mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py (new file)

@@ -0,0 +1,211 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""AscendNative Llama Boost APIs."""
+
+import os
+import numpy as np
+from mindspore.common import Tensor, dtype
+from mindspore.experimental.llm_boost.ascend_native.llm_boost import LLMBoost
+from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
+
+
+def RoundUp(val: int, align: int) -> int:
+    if align == 0:
+        return 0
+    return -(val // -align) * align
+
+
+def ConvertTensor(nd_mat: np.ndarray, transpose: bool = True, nd2nz: bool = True) -> np.ndarray:
+    """ Transforms tensor format from Nd to Nz """
+    if transpose:
+        nd_mat = np.transpose(nd_mat)
+    if not nd2nz:
+        return nd_mat
+    block_size = (16, 16)
+    r = RoundUp(nd_mat.shape[0], block_size[0])
+    c = RoundUp(nd_mat.shape[1], block_size[1])
+    r_pad = r - nd_mat.shape[0]
+    c_pad = c - nd_mat.shape[1]
+    nd_mat = np.pad(nd_mat, ((0, r_pad), (0, c_pad)))
+    nz_mat = np.transpose(np.reshape(
+        nd_mat, (r, c // block_size[1], block_size[1])), (1, 0, 2))
+    nz_mat = nz_mat.reshape(r, c)
+    return nz_mat
+
+
+@LlmBoostRegister.register(LlmBoostType.ASCEND_NATIVE, "Llama")
+class LlamaBoostAscendNative(LLMBoost):
+    r"""
+    Implements an Llama model in a single kernel.
+    it forwards the python functions to the C++ binded object
+    """
+    def _get_from_dict(self, dictionary, name):
+        """ internal function to get a specific tensor from the dictionary """
+        all_relevant_layers = [value for key, value in dictionary.items() if name in key]
+        if all_relevant_layers:
+            return all_relevant_layers[0].asnumpy()
+        return None
+
+    def _get_quant_triplet_from_dict(self, dictionary, name):
+        """ internal function to get a weight triple tensor from the dictionary """
+        weights = self._get_from_dict(dictionary, name + "._handler.weight")
+        scale = self._get_from_dict(dictionary, name + "._weight_quantizer.scale")
+        offset = self._get_from_dict(dictionary, name + "._weight_quantizer.zp_neg")
+        return weights, scale, offset
+
+    def _prepare_single_layer(self, ckpt, config, id):
+        """ prepares the dictionary of weights of a single layer """
+        prefix = 'model.layers.' + str(id)
+        is_last = (id == config.num_layers-1)
+        layer = 'layers.' + str(id) + '.'
+        l_dict = {key: value for key, value in ckpt.items() if layer in key}
+        if config.n_kv_heads is None:
+            config.n_kv_heads = config.num_heads
+        start = 0
+        end = config.hidden_size
+        kv_start = 0
+        kv_end = int(config.hidden_size*config.n_kv_heads/config.num_heads)
+        ffn_hid = [value for key, value in l_dict.items() if "w3" in key][0].shape[0]
+        ffn_start = 0
+        ffn_end = ffn_hid
+        rank_size = int(os.getenv('RANK_SIZE', '1'))
+        #Emir if (config.parallel_mode != 2): # 2 - AUTO_PARALLEL
+        hid_size = end
+        kv_hid_size = kv_end
+        embed_size = config.vocab_size
+        rank_id = int(os.getenv('RANK_ID', '0'))
+        if (hid_size % rank_size == 0) and (ffn_hid % rank_size == 0) and (embed_size % rank_size == 0):
+            start = int(rank_id * hid_size / rank_size)
+            end = int((rank_id + 1) * hid_size / rank_size)
+            kv_start = int(rank_id * kv_hid_size / rank_size)
+            kv_end = int((rank_id + 1) * kv_hid_size / rank_size)
+            ffn_start = int(rank_id * ffn_hid / rank_size)
+            ffn_end = int((rank_id + 1) * ffn_hid / rank_size)
+        else:
+            raise RuntimeError("hidden size and ffn hidden size must be divided by rank size without remainder. \
+                hidden_size: ", hid_size, " ffn_hidden_size: ", ffn_hid, " rank_size: ", rank_size)
+        quant = (self._get_from_dict(l_dict, "_weight_quantizer") is not None)
+        unite_qkv = (config.num_heads == config.n_kv_heads)
+        self.dictionary[prefix + ".attention_norm.weight"] = \
+            Tensor(self._get_from_dict(l_dict, "attention_norm"), dtype=dtype.float16)
+        self.dictionary[prefix + ".ffn_norm.weight"] = \
+            Tensor(self._get_from_dict(l_dict, "ffn_norm"), dtype=dtype.float16)
+        if is_last:
+            self.dictionary['lm_head.weight'] = Tensor(ConvertTensor(ckpt['lm_head.weight'].asnumpy()[:, start:end]))
+
+        if not quant:
+            self._pack_attn_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
+            self._pack_ffn_weights(l_dict, prefix, ffn_start, ffn_end)
+        else:
+            self._pack_attn_quant_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
+            self._pack_ffn_quant_weights(l_dict, prefix, ffn_start, ffn_end)
+
+    def _pack_attn_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
+        """ prepares the dictionary of weights of an attention block """
+        wq = self._get_from_dict(l_dict, "wq")[start:end, :]
+        wk = self._get_from_dict(l_dict, "wk")[kv_start:kv_end, :]
+        wv = self._get_from_dict(l_dict, "wv")[kv_start:kv_end, :]
+        self.dictionary[prefix + ".attention.wo.weight"] = \
+            Tensor(ConvertTensor(self._get_from_dict(l_dict, "wo")[:, start:end]))
+        if unite_qkv:
+            self.dictionary[prefix + ".attention.wqkv.weight"] = Tensor(ConvertTensor(np.concatenate((wq, wk, wv))))
+        else:
+            self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq))
+            self.dictionary[prefix + ".attention.wkv.weight"] = Tensor(ConvertTensor(np.concatenate((wk, wv))))
+
+    def _pack_ffn_weights(self, l_dict, prefix, ffn_start, ffn_end):
+        """ prepares the dictionary of weights of an ffn block """
+        self.dictionary[prefix + ".feed_forward.w2.weight"] = \
+            Tensor(ConvertTensor(self._get_from_dict(l_dict, "w2")[:, ffn_start:ffn_end]))
+        w1 = self._get_from_dict(l_dict, "w1")[ffn_start:ffn_end, :]
+        w3 = self._get_from_dict(l_dict, "w3")[ffn_start:ffn_end, :]
+        self.dictionary[prefix + ".feed_forward.w13.weight"] = Tensor(ConvertTensor(np.concatenate((w1, w3))))
+
+    def _pack_attn_quant_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
+        """ prepares the dictionary of weights of a quantized attention block """
+        wq, wq_scale, wq_offset = self._get_quant_triplet_from_dict(l_dict, "wq")
+        wk, wk_scale, wk_offset = self._get_quant_triplet_from_dict(l_dict, "wk")
+        wv, wv_scale, wv_offset = self._get_quant_triplet_from_dict(l_dict, "wv")
+        wo, wo_scale, wo_offset = self._get_quant_triplet_from_dict(l_dict, "wo")
+        self.dictionary[prefix + ".attention.wo.weight"] = Tensor(ConvertTensor(wo[:, start:end], nd2nz=False))
+        self.dictionary[prefix + ".attention.wo.weight.scale"] = Tensor(wo_scale[start:end])
+        self.dictionary[prefix + ".attention.wo.weight.offset"] = Tensor(wo_offset[start:end])
+
+        if unite_qkv:
+            self.dictionary[prefix + ".attention.wqkv.weight"] = \
+                Tensor(ConvertTensor(np.concatenate((wq[start:end, :], wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])),
+                                     nd2nz=False))
+            self.dictionary[prefix + ".attention.wqkv.weight.scale"] = \
+                Tensor(np.concatenate((wq_scale[start:end], wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
+            self.dictionary[prefix + ".attention.wqkv.weight.offset"] = \
+                Tensor(np.concatenate((wq_offset[start:end], wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
+        else:
+            self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq[start:end, :], nd2nz=False))
+            self.dictionary[prefix + ".attention.wq.weight.scale"] = Tensor(wq_scale[start:end])
+            self.dictionary[prefix + ".attention.wq.weight.offset"] = Tensor(wq_offset[start:end])
+            self.dictionary[prefix + ".attention.wkv.weight"] = \
+                Tensor(ConvertTensor(np.concatenate((wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])), nd2nz=False))
+            self.dictionary[prefix + ".attention.wkv.weight.scale"] = \
+                Tensor(np.concatenate((wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
+            self.dictionary[prefix + ".attention.wkv.weight.offset"] = \
+                Tensor(np.concatenate((wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
+
+    def _pack_ffn_quant_weights(self, l_dict, prefix, ffn_start, ffn_end):
+        """ prepares the dictionary of weights of a quantized ffn block """
+        w1, w1_scale, w1_offset = self._get_quant_triplet_from_dict(l_dict, "w1")
+        w2, w2_scale, w2_offset = self._get_quant_triplet_from_dict(l_dict, "w2")
+        w3, w3_scale, w3_offset = self._get_quant_triplet_from_dict(l_dict, "w3")
+        self.dictionary[prefix + ".feed_forward.w2.weight"] = Tensor(ConvertTensor(w2[:, ffn_start:ffn_end],
+                                                                                   nd2nz=False))
+        self.dictionary[prefix + ".feed_forward.w2.weight.scale"] = Tensor(w2_scale[ffn_start:ffn_end])
+        self.dictionary[prefix + ".feed_forward.w2.weight.offset"] = Tensor(w2_offset[ffn_start:ffn_end])
+
+        self.dictionary[prefix + ".feed_forward.w13.weight"] = \
+            Tensor(ConvertTensor(np.concatenate((w1[ffn_start:ffn_end, :], w3[ffn_start:ffn_end, :])), nd2nz=False))
+        self.dictionary[prefix + ".feed_forward.w13.weight.scale"] = \
+            Tensor(np.concatenate((w1_scale[ffn_start:ffn_end], w3_scale[ffn_start:ffn_end])))
+        self.dictionary[prefix + ".feed_forward.w13.weight.offset"] = \
+            Tensor(np.concatenate((w1_offset[ffn_start:ffn_end], w3_offset[ffn_start:ffn_end])))
+
+    def _prepare_cos_sin_arrays(self, config, theta=10000):
+        """ prepares the cosine and sine arrays """
+        head_dim = config.hidden_size // config.num_heads
+        max_position_embedding = \
+            config.max_position_embedding if config.max_position_embedding is not None else config.seq_length
+        freqs_base = np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(np.float32)
+        freqs = 1.0 / (theta ** (freqs_base / head_dim))
+        t = np.arange(0, max_position_embedding, 1).astype(np.float32)
+        freqs = np.outer(t, freqs)
+        emb = np.concatenate((freqs, freqs), axis=-1)
+        freqs_cos = Tensor(np.cos(emb), dtype=dtype.float16)
+        sin = np.sin(emb)
+
+        sin[:, :int(emb.shape[1]/2)] = -sin[:, :int(emb.shape[1]/2)]
+        self.dictionary['model.cos.weight'] = freqs_cos
+        freqs_sin = Tensor(sin, dtype=dtype.float16)
+        self.dictionary['model.sin.weight'] = freqs_sin
+
+    def set_weights(self, ckpt_dict):
+        """ load the checkpoint """
+        self.dictionary = {}
+        self.dictionary['model.tok_embeddings.embedding_weight'] = \
+            Tensor(ckpt_dict['model.tok_embeddings.embedding_weight'].asnumpy())
+        self.dictionary['model.norm_out.weight'] = \
+            Tensor(ckpt_dict['model.norm_out.weight'].asnumpy(), dtype=dtype.float16)
+        self._prepare_cos_sin_arrays(self.config)
+        for layer_id in range(self.config.num_layers):
+            self._prepare_single_layer(ckpt_dict, self.config, layer_id)
+
+        self.binder.set_weights_map(self.dictionary)
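The `ConvertTensor` helper above is the heart of the weight packing: it transposes, zero-pads both dimensions up to multiples of the 16x16 block size, then regroups 16-wide column blocks so each block is contiguous (the Ascend "Nz" layout). A small self-contained NumPy check of that reshape, reusing the same logic from the diff on a toy matrix:

```python
import numpy as np

def round_up(val: int, align: int) -> int:
    """Ceil val to a multiple of align, mirroring RoundUp above."""
    return 0 if align == 0 else -(val // -align) * align

# A 3x20 half-precision matrix: neither dim is a multiple of the 16x16 block.
nd = np.arange(60, dtype=np.float16).reshape(3, 20)
mat = nd.T                                   # transpose=True path: now 20x3
r = round_up(mat.shape[0], 16)               # 32
c = round_up(mat.shape[1], 16)               # 16
mat = np.pad(mat, ((0, r - mat.shape[0]), (0, c - mat.shape[1])))
# Split the columns into 16-wide blocks and hoist the block index to the
# front, so every 16-column fractal becomes contiguous (the Nz layout).
nz = np.transpose(mat.reshape(r, c // 16, 16), (1, 0, 2)).reshape(r, c)
print(nz.shape)                              # (32, 16)
```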
mindspore/experimental/llm_boost/ascend_native/llm_boost.py (new file)

@@ -0,0 +1,52 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LLMBoost APIs."""
+
+from mindspore.common import Tensor
+
+class LLMBoost():
+    r"""
+    Implements an LLM in a single kernel.
+    it forwards the python function to the C++ binded object
+    """
+    def __init__(self, config):
+        r"""
+        initialize the parameters of the llm binder.
+        config is simply the config object of the model
+        """
+        from mindspore._c_expression import LlmBoostBinder
+        self.config = config
+        self.binder = LlmBoostBinder("AscendNative", config.model_type)
+        self.binder.init_model(config.to_dict())
+
+    def init(self):
+        """
+        Initialize the object
+        returns True if object needs input manipulation by mindformers
+        """
+        return False
+
+    def set_kvcache(self, k_caches=None, v_caches=None):
+        return
+
+    def forward(self, input_ids, batch_valid_length, position_ids=None):
+        ret = self.binder.forward([input_ids, batch_valid_length], "nothing really")
+        return Tensor(ret[0])
+
+    def set_weights(self, ckpt_dict):
+        self.binder.set_weights_map(ckpt_dict)
+
+    def add_flags(self, is_first_iteration=False):
+        self.binder.add_flags(is_first_iteration=is_first_iteration)
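The `@LlmBoostRegister.register(LlmBoostType.ASCEND_NATIVE, "Llama")` decorator used in `llama_boost_ascend_native.py` above is a conventional decorator-based registry: registration happens at import time, and a factory later looks the class up by (backend, model) key. For readers unfamiliar with the pattern, a self-contained sketch of the mechanism (an illustration only, not MindSpore's actual `LlmBoostRegister` implementation):

```python
class Registry:
    """Minimal decorator-based registry, illustrating the pattern only."""
    def __init__(self):
        self._entries = {}

    def register(self, backend, name):
        def wrap(cls):
            self._entries[(backend, name)] = cls  # record at import time
            return cls                            # class is left unchanged
        return wrap

    def get(self, backend, name):
        return self._entries[(backend, name)]

demo_registry = Registry()

@demo_registry.register("AscendNative", "Llama")
class DemoBoost:
    pass

assert demo_registry.get("AscendNative", "Llama") is DemoBoost
```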
mindspore/experimental/llm_boost/atb/boost_base.py

@@ -112,8 +112,7 @@ class AtbBoostBase:
 
     def _convert_qkv_concat_weight(self, param_dict):
         """convert qkv concat weight"""
-
-        for i in range(assume_num_layers):
+        for i in range(self.num_layers):
             # qkv weight concat
             wq_weight_name = f"model.layers.{i}.attention.wq.weight"
             wk_weight_name = f"model.layers.{i}.attention.wk.weight"
@@ -151,7 +150,7 @@ class AtbBoostBase:
         logger.info(f"transform: {qkv_concat_weight_name}")
         logger.info(f"transform: {gate_hidden_concat_weight_name}")
 
-        for i in range(
+        for i in range(self.num_layers):
             # qkv bias concat
             wq_bias_name = f"model.layers.{i}.attention.wq.bias"
             wk_bias_name = f"model.layers.{i}.attention.wk.bias"
mindspore/experimental/llm_boost/atb/llama_boost.py

@@ -43,7 +43,11 @@ class LlamaBoost(AtbBoostBase):
         )
 
     def init(self):
-        """
+        """
+        Initialize the object
+        returns True if object needs input manipulation by mindformers
+        """
+
         coder_param = {
             "normEps": self.config.rms_norm_eps,
             "normType": NormType.RMS_NORM,
@@ -93,6 +97,7 @@ class LlamaBoost(AtbBoostBase):
         }
         self.atb_encoder_operation.init(json.dumps({**encoder_param}))
         self.atb_decoder_operation.init(json.dumps({**decoder_param}))
+        return True
 
     def _prepare_inputs(
         self,
mindspore/experimental/map_parameter.py

@@ -23,7 +23,7 @@ from copy import copy
 import numbers
 import mindspore as ms
 from mindspore.common.parameter import Parameter, _get_unique_parameter_key
-from mindspore._c_expression import
+from mindspore._c_expression import TensorPy as Tensor_
 from mindspore._c_expression import MapTensor_
 from mindspore.ops.operations import _map_tensor_ops
 
@@ -78,12 +78,12 @@ class MapParameter(Parameter):
         if value_dtype is not None:
             if isinstance(value_shape, numbers.Number):
                 value_shape = (value_shape,)
-            data = Tensor_(value_dtype, value_shape)
+            data = Tensor_(dtype=value_dtype, shape=value_shape)
         elif value_tensor is not None:
-            data = Tensor_(value_tensor.dtype, value_tensor.shape)
+            data = Tensor_(dtype=value_tensor.dtype, shape=value_tensor.shape)
         else:
             # default
-            data = Tensor_(ms.float32, (1,))
+            data = Tensor_(dtype=ms.float32, shape=(1,))
         obj = Tensor_.__new__(cls)
         Tensor_.__init__(obj, data)
         # Compatible attributes with Parameter.
mindspore/experimental/optim/adadelta.py

@@ -37,14 +37,14 @@ class Adadelta(Optimizer):
     Implements Adadelta algorithm.
 
     .. math::
-
-        &\rule{
+        \begin{aligned}
+            &\rule{180mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
                \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
                \: \lambda \text{ (weight decay)} \\
            &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
                \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
-        &\rule{
+            &\rule{180mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\

@@ -55,10 +55,10 @@ class Adadelta(Optimizer):
            &\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
                \Delta x^2_t (1 - \rho) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
-        &\rule{
+            &\rule{180mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
-        &\rule{
-
+            &\rule{180mm}{0.4pt} \\[-1.ex]
+        \end{aligned}
 
     .. warning::
         This is an experimental optimizer API that is subject to change.
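The restored pseudocode translates almost line-for-line into NumPy. A minimal sketch of one Adadelta step follows; the definition of Delta x_t is not visible in this hunk's context lines, so it is filled in from the standard algorithm as sqrt(u_{t-1}+eps)/sqrt(v_t+eps) * g_t:

```python
import numpy as np

def adadelta_step(theta, grad, v, u, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0):
    """One Adadelta update, following the pseudocode in the docstring above."""
    if weight_decay != 0.0:
        grad = grad + weight_decay * theta              # g_t <- g_t + lambda * theta_{t-1}
    v = rho * v + (1.0 - rho) * grad**2                 # square-avg accumulator v_t
    delta = np.sqrt(u + eps) / np.sqrt(v + eps) * grad  # Delta x_t (standard definition)
    u = rho * u + (1.0 - rho) * delta**2                # accumulate variable u_t
    theta = theta - lr * delta                          # theta_t <- theta_{t-1} - gamma * Delta x_t
    return theta, v, u

theta, v, u = np.zeros(3), np.zeros(3), np.zeros(3)
theta, v, u = adadelta_step(theta, np.array([0.1, -0.2, 0.3]), v, u)
```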
mindspore/experimental/optim/adagrad.py

@@ -38,12 +38,12 @@ class Adagrad(Optimizer):
 
     .. math::
         \begin{aligned}
-            &\rule{
+            &\rule{160mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
            &\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
-            &\rule{
+            &\rule{160mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\

@@ -52,9 +52,9 @@ class Adagrad(Optimizer):
            &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
-            &\rule{
+            &\rule{160mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
-            &\rule{
+            &\rule{160mm}{0.4pt} \\[-1.ex]
         \end{aligned}
 
     .. warning::
mindspore/experimental/optim/adam.py

@@ -49,12 +49,14 @@ class Adam(Optimizer):
 
     .. math::
         \begin{aligned}
+            &\rule{180mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
                \:\textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
+            &\rule{180mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\

@@ -74,10 +76,15 @@ class Adam(Optimizer):
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
+            &\rule{180mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
+            &\rule{180mm}{0.4pt} \\[-1.ex]
         \end{aligned}
 
     .. warning::
+        The implementation formula of this optimizer interface is not completely consistent with that in the paper.
+        If you want to use an interface that is completely consistent, it is recommended to use
+        :class:`mindspore.mint.optim.Adam`, which currently only supports Ascend.
         This is an experimental optimizer API that is subject to change.
         This module must be used with lr scheduler module in `LRScheduler Class
         <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_ .
mindspore/experimental/optim/adamax.py

@@ -43,14 +43,14 @@ class Adamax(Optimizer):
 
     .. math::
         \begin{aligned}
-            &\rule{
+            &\rule{180mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
                \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
                \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \epsilon \text{ (epsilon)} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
-            &\rule{
+            &\rule{180mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}if \: \lambda \neq 0 \\

@@ -58,9 +58,9 @@ class Adamax(Optimizer):
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
-            &\rule{
+            &\rule{180mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
-            &\rule{
+            &\rule{180mm}{0.4pt} \\[-1.ex]
         \end{aligned}
 
     .. warning::
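Adamax's distinctive step, fully visible in the hunk above, replaces Adam's second-moment average with the infinity-norm accumulator u_t. A minimal NumPy sketch of one step, following that pseudocode directly:

```python
import numpy as np

def adamax_step(theta, grad, m, u, t, lr=0.002, beta1=0.9, beta2=0.999,
                eps=1e-8, weight_decay=0.0):
    """One Adamax update (t is the 1-based step count), per the docstring above."""
    if weight_decay != 0.0:
        grad = grad + weight_decay * theta           # g_t <- g_t + lambda * theta_{t-1}
    m = beta1 * m + (1.0 - beta1) * grad             # first moment m_t
    u = np.maximum(beta2 * u, np.abs(grad) + eps)    # infinity norm u_t
    theta = theta - lr * m / ((1.0 - beta1**t) * u)  # bias-corrected update
    return theta, m, u

theta, m, u = np.zeros(3), np.zeros(3), np.zeros(3)
theta, m, u = adamax_step(theta, np.array([0.1, -0.2, 0.3]), m, u, t=1)
```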